Source code for sadaco.apis.traintest.demo

import cv2
import torch
import torchaudio
import numpy as np
from sadaco.apis.explain.hookman import FGHandler
from sadaco.utils.config_parser import parse_config_obj
from sadaco.apis.models import build_model
from sadaco.apis.explain.visualize import spec_display
from sadaco.apis.traintest.common import load_input

class demo_helper:
    """Wrap a trained 4-class classifier for demo inference and saliency maps.

    The class labels, as printed by ``do_inference``, are Normal / Wheeze /
    Crackle / Both. On construction the model is built from ``model_cfg``,
    moved to CUDA, put in eval mode and restored from
    ``model_cfg.model_checkpoint``. An ``FGHandler`` is attached as
    ``model.handler`` over the hooked layers so that ``do_explanation`` can
    read intermediate features and gradients.
    """

    # Module names handed to FGHandler when no ``layers`` argument is given.
    DEFAULT_LAYERS = (
        'layer4.2.conv3', 'layer3.2.conv3', 'layer2.2.conv3',
        'layer1.2.conv3', 'layer4.1.conv3', 'layer3.1.conv3',
    )

    def __init__(self, master_cfg, model_cfg, layers=None):
        """Build, load and instrument the model.

        Args:
            master_cfg: Top-level config. ``do_explanation`` reads
                ``master_cfg.explainer.ig.ig_steps`` from it.
            model_cfg: Model config passed to ``build_model``. It must provide
                ``model_checkpoint``, the path given to ``torch.load``. The
                checkpoint must contain a ``'state_dict'`` entry.
            layers: Optional iterable of module names to hook with
                ``FGHandler``. Defaults to ``DEFAULT_LAYERS``.
        """
        self.master_cfg = master_cfg
        self.model_cfg = model_cfg

        model = build_model(model_cfg)
        model = model.cuda()
        model.eval()

        checkpoint = torch.load(model_cfg.model_checkpoint)
        try:
            model.load_state_dict(checkpoint['state_dict'])
        except RuntimeError:
            # NOTE(review): presumably the checkpoint was saved from a
            # DataParallel-wrapped model ("module."-prefixed keys). Load it
            # through a wrapper, then unwrap back to the bare module.
            model = torch.nn.DataParallel(model)
            model.load_state_dict(checkpoint['state_dict'])
            model = model.module

        if layers is None:
            layers = self.DEFAULT_LAYERS
        model.handler = FGHandler(model, list(layers))
        self.model = model

    def do_inference(self, input_path, return_raw=False):
        """Classify one input and report per-class softmax probabilities.

        Args:
            input_path: Anything accepted by ``load_input``.
            return_raw: If true, return the loaded input and the probability
                tensor instead of a formatted string.

        Returns:
            A human-readable percentage summary. If ``return_raw`` is set,
            returns ``(inputs, probs)`` instead. ``probs`` is the softmax over
            dim 1 of the model output for the single batch element, and it
            still carries the autograd graph.
        """
        inputs = load_input(input_path)
        outputs = self.model(inputs.unsqueeze(0))
        # Softmax over classes; [0] picks the single sample in the batch.
        outputs = torch.softmax(outputs, dim=1)[0]
        if return_raw:
            return inputs, outputs
        return (f"Normal {outputs[0]*100:.2f} %, Wheeze {outputs[1]*100:.2f} %\n"
                f"Crackle {outputs[2]*100:.2f} %, Both {outputs[3]*100:.2f} %")

    def do_explanation(self, input_path, method, cls):
        """Render a saliency-weighted view of the input for class ``cls``.

        Args:
            input_path: Anything accepted by ``load_input``.
            method: ``0`` selects Grad-CAM, computed from the hooked features
                and gradients. ``1`` selects Integrated Gradients with a zero
                baseline and ``master_cfg.explainer.ig.ig_steps`` steps.
            cls: Index of the class whose score is explained.

        Returns:
            Whatever ``spec_display(..., return_array=True)`` returns for the
            first channel of ``inputs * saliency``. Presumably this is an image
            array; confirm against ``spec_display``.

        Raises:
            ValueError: If ``method`` is not 0 or 1, or ``ig_steps < 1``.
        """
        if method == 0:
            inputs, outputs = self.do_inference(input_path, return_raw=True)
            saliency = self._grad_cam(inputs, outputs, cls)
        elif method == 1:
            inputs = load_input(input_path)
            steps = self.master_cfg.explainer.ig.ig_steps
            saliency = self._integrated_gradients(inputs, cls, steps)
        else:
            raise ValueError(
                f"Unknown explanation method {method!r}; expected 0 "
                f"(Grad-CAM) or 1 (Integrated Gradients)")
        return self._overlay(inputs, saliency)

    def _grad_cam(self, inputs, outputs, cls):
        """Return Grad-CAM maps in [0, 1], resized to the input's last two dims."""
        # Clear parameter .grad left by earlier calls so repeated explanations
        # do not keep accumulating into it.
        self.model.zero_grad()
        # Backpropagates from the softmax probability, not from a raw logit.
        outputs[cls].backward()
        # NOTE(review): the 'upscale' arguments are opaque here; see FGHandler.
        feature = self.model.handler.get_all_features('upscale', 'upscale')
        gradient = self.model.handler.get_all_grads('upscale', 'upscale')
        # Per-channel weights: gradient averaged over the two trailing axes.
        weight = np.mean(gradient.detach().cpu().numpy(), axis=(2, 3))
        cam = feature.detach().cpu().numpy() * weight[:, :, np.newaxis, np.newaxis]
        cam = np.maximum(np.sum(cam, axis=1), 0)
        cam = self._normalize_maps(cam)
        # cv2.resize takes dsize as (width, height); passing the tensor's
        # (height, width) would transpose non-square maps.
        dsize = self._cv2_dsize(inputs.shape[-2:])
        return np.array([cv2.resize(c, dsize) for c in cam])

    def _integrated_gradients(self, inputs, cls, steps):
        """Return zero-baseline Integrated Gradients, ReLU'd and min-max normalized."""
        if steps < 1:
            raise ValueError(f"ig_steps must be >= 1, got {steps!r}")
        baseline = torch.zeros_like(inputs)
        delta = inputs - baseline
        avg_grad = torch.zeros_like(inputs)
        self.model.zero_grad()
        # Right Riemann sum over k = 1..steps. The earlier loop also sampled
        # k = 0 (the baseline) while still weighting every term by 1/steps.
        for k in range(1, steps + 1):
            point = (baseline + delta * float(k) / steps).detach().requires_grad_(True)
            outputs = self.model(point.unsqueeze(0))
            (grad,) = torch.autograd.grad(outputs[0, cls], point)
            avg_grad += grad / steps
        attribution = torch.relu(delta * avg_grad)
        return self._normalize_maps(attribution.detach().cpu().numpy())

    @staticmethod
    def _normalize_maps(maps, eps=1e-16):
        """Min-max scale each ``maps[i]`` independently to [0, 1].

        The divisor is floored at ``eps``, so a constant map becomes all zeros
        instead of NaN.
        """
        count = maps.shape[0]
        bshape = (count,) + (1,) * (maps.ndim - 1)
        shifted = maps - maps.reshape(count, -1).min(axis=1).reshape(bshape)
        peak = shifted.reshape(count, -1).max(axis=1).reshape(bshape)
        return shifted / np.maximum(peak, eps)

    @staticmethod
    def _cv2_dsize(spatial_shape):
        """Convert a ``(height, width)`` shape into cv2's ``(width, height)`` dsize."""
        height, width = spatial_shape
        return int(width), int(height)

    @staticmethod
    def _overlay(inputs, saliency):
        """Weight ``inputs`` by ``saliency`` and render the first channel."""
        weighted = inputs.detach().cpu().numpy() * saliency
        return spec_display(weighted[0].astype(np.float32), return_array=True)