import cv2
import torch
import torchaudio
import numpy as np
from sadaco.apis.explain.hookman import FGHandler
from sadaco.utils.config_parser import parse_config_obj
from sadaco.apis.models import build_model
from sadaco.apis.explain.visualize import spec_display
from sadaco.apis.traintest.common import load_input
class demo_helper:
    """Wrap a trained classifier with inference and explanation helpers.

    On construction the model described by ``model_cfg`` is built, moved to
    CUDA, switched to eval mode, loaded from ``model_cfg.model_checkpoint``
    and instrumented with an ``FGHandler`` over a fixed list of layer names.
    """

    def __init__(self, master_cfg, model_cfg):
        """Build the model, load its checkpoint and attach the hook handler.

        Args:
            master_cfg: Global configuration; ``explainer.ig.ig_steps`` is
                read later by the integrated-gradients explanation.
            model_cfg: Model configuration passed to ``build_model``; its
                ``model_checkpoint`` attribute is the checkpoint path.
        """
        self.master_cfg = master_cfg
        self.model_cfg = model_cfg

        model = build_model(model_cfg)
        model = model.cuda()
        model.eval()

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # Loading onto CPU avoids failures when the checkpoint was saved from
        # a CUDA device that does not exist here; load_state_dict then copies
        # the tensors onto the model's own device.
        checkpoint = torch.load(model_cfg.model_checkpoint, map_location='cpu')
        try:
            model.load_state_dict(checkpoint['state_dict'])
        except RuntimeError:
            # Retry through a DataParallel wrapper (presumably for checkpoints
            # whose keys carry the wrapper's prefix), then unwrap again.
            model = torch.nn.DataParallel(model)
            model.load_state_dict(checkpoint['state_dict'])
            model = model.module

        layers = ['layer4.2.conv3', 'layer3.2.conv3', 'layer2.2.conv3', 'layer1.2.conv3',
                  'layer4.1.conv3', 'layer3.1.conv3']
        handler = FGHandler(model, layers)
        model.handler = handler
        self.model = model

    def do_inference(self, input_path, return_raw=False):
        """Run the model on one input and report class probabilities.

        Args:
            input_path: Passed unchanged to ``load_input``.
            return_raw: If true, return the tensors instead of a text summary.

        Returns:
            ``(inputs, outputs)`` when ``return_raw`` is true, where
            ``outputs`` is the softmax over the model output for the single
            batch element; otherwise a two-line string with the percentages
            of the first four classes.
        """
        inputs = load_input(input_path)
        outputs = self.model(inputs.unsqueeze(0))
        outputs = torch.softmax(outputs, dim=1)[0]
        if return_raw:
            # Skip building the summary string when it is not needed.
            return inputs, outputs
        return f"Normal {outputs[0]*100:.2f} %, Wheeze {outputs[1]*100:.2f} %\nCrackle {outputs[2]*100:.2f} %, Both {outputs[3]*100:.2f} %"

    def do_explanation(self, input_path, method, cls):
        """Render an attribution-weighted view of the input for class ``cls``.

        Args:
            input_path: Passed unchanged to ``load_input``.
            method: ``0`` for the gradient-weighted feature-map (Grad-CAM
                style) explanation, ``1`` for integrated gradients with an
                all-zero baseline.
            cls: Index of the class whose score is explained.

        Returns:
            Whatever ``spec_display(..., return_array=True)`` returns for the
            input multiplied element-wise by the normalised attribution map.

        Raises:
            ValueError: If ``method`` is neither 0 nor 1, or if the configured
                number of integrated-gradients steps is smaller than 1.
        """
        if method == 0:
            return self._explain_grad_cam(input_path, cls)
        if method == 1:
            return self._explain_integrated_gradients(input_path, cls)
        raise ValueError(f"Unsupported explanation method: {method!r} (expected 0 or 1)")

    @staticmethod
    def _rescale_unit(maps, eps=1e-16):
        """Rectify ``maps`` and min-max scale each leading-axis entry to [0, 1]."""
        maps = np.maximum(maps, 0)
        maps = maps - np.min(maps, axis=(1, 2), keepdims=True)
        # Same floor as the integrated-gradients branch: an all-zero map
        # stays zero instead of turning into NaNs.
        peak = np.maximum(np.max(maps, axis=(1, 2), keepdims=True), eps)
        return maps / peak

    def _explain_grad_cam(self, input_path, cls):
        """Weight hooked feature maps by their mean gradients and overlay them."""
        inputs, outputs = self.do_inference(input_path, return_raw=True)
        # Drop parameter gradients left over from earlier calls so repeated
        # explanations do not keep accumulating them.
        self.model.zero_grad()
        outputs[cls].backward()

        feature = self.model.handler.get_all_features('upscale', 'upscale')
        gradient = self.model.handler.get_all_grads('upscale', 'upscale')

        # One weight per channel: the gradient averaged over axes 2 and 3.
        weight = np.mean(gradient.cpu().numpy(), axis=(2, 3))
        cam = feature.detach().cpu().numpy() * weight[:, :, np.newaxis, np.newaxis]
        cam = np.sum(cam, axis=1)
        cam = self._rescale_unit(cam)

        # cv2.resize takes dsize as (width, height), i.e. the reverse of the
        # trailing (height, width) array dimensions.
        height, width = inputs.shape[-2:]
        cam = np.array([cv2.resize(c, (int(width), int(height))) for c in cam])

        # NOTE: a histogram-equalisation step on the map used to live here and
        # was disabled upstream; the raw normalised map is used instead.
        cam2 = inputs.cpu().numpy() * cam
        return spec_display(cam2[0].astype(np.float32), return_array=True)

    def _explain_integrated_gradients(self, input_path, cls):
        """Accumulate gradients along a zero-to-input path and overlay them."""
        inputs = load_input(input_path)
        steps = self.master_cfg.explainer.ig.ig_steps
        if steps < 1:
            raise ValueError(f"ig_steps must be at least 1, got {steps!r}")

        baseline = torch.zeros_like(inputs)
        delta = inputs - baseline
        IG = torch.zeros_like(inputs)

        self.model.zero_grad()
        # steps + 1 evenly spaced points from the baseline (i = 0) to the
        # input itself (i = steps), each contributing grad / steps.
        for i in range(steps + 1):
            scaled_inputs = (baseline + delta * float(i) / steps).detach().requires_grad_(True)
            outputs = self.model(scaled_inputs.unsqueeze(0))
            loss_grads = torch.autograd.grad(outputs[0, cls], scaled_inputs)
            IG += loss_grads[0] / steps

        IG = torch.relu(delta * IG)
        # Min-max scale per leading-axis entry; assumes inputs is 3-D
        # (e.g. channel x height x width) -- TODO confirm against load_input.
        IG -= torch.flatten(IG, start_dim=1).min(dim=-1)[0][:, None, None]
        IG /= torch.clamp(torch.flatten(IG, start_dim=1).max(dim=-1)[0][:, None, None],
                          min=1e-16)
        IG = IG.detach().cpu().numpy()

        # NOTE: a histogram-equalisation step on the map used to live here and
        # was disabled upstream; the raw normalised map is used instead.
        IG2 = inputs.cpu().numpy() * IG
        return spec_display(IG2[0].astype(np.float32), return_array=True)