Source code for sadaco.apis.models.torchvggish.vgg

import torch
from .torchvggish.vggish import VGGish as _VGGish

# Download locations for the torchvggish v0.1 release assets: the VGGish
# network weights ('vggish') and the vggish_pca_params file ('pca').
# This mapping is passed as ``urls=`` to the torchvggish VGGish constructor.
model_urls = dict(
    vggish=('https://github.com/harritaylor/torchvggish/releases/'
            'download/v0.1/vggish-10086976.pth'),
    pca=('https://github.com/harritaylor/torchvggish/releases/'
         'download/v0.1/vggish_pca_params-970ea276.pth'),
)

def vggish(**kwargs):
    """Construct a torchvggish ``VGGish`` model bound to ``model_urls``.

    Every keyword argument is forwarded unchanged to the torchvggish
    ``VGGish`` constructor. The asset URLs are always taken from the
    module-level ``model_urls`` mapping.
    """
    return _VGGish(urls=model_urls, **kwargs)
class VGGish(torch.nn.Module):
    """VGGish backbone followed by a three-layer MLP classification head.

    Args:
        num_classes: Number of outputs produced by the final Linear layer.
        freeze: When True, ``requires_grad`` is switched off for every
            parameter of the VGGish backbone, so only the head is trainable.
        **kwargs: Forwarded unchanged to :func:`vggish`.
    """

    def __init__(self, num_classes, freeze=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.freeze = freeze

        # The backbone is created and registered before the head. Keep this
        # order: it fixes parameter ordering and random-init order.
        self.base = vggish(**kwargs)
        if freeze:
            for weight in self.base.parameters():
                weight.requires_grad_(False)

        # The head consumes the 128-d output of base.embeddings. Its hidden
        # width is derived from that size and the number of classes.
        embed_dim = 128
        self.hidden = (num_classes + embed_dim) // 3
        wide = self.hidden * 2
        head_layers = [
            torch.nn.Linear(embed_dim, wide),
            torch.nn.ReLU(),
            torch.nn.Linear(wide, self.hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden, num_classes),
        ]
        self.classifier = torch.nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the backbone and classification head on a batch.

        Args:
            x: 4-D tensor in SADACO's ``[N, C=1, Mel, Time]`` layout.

        Returns:
            The classifier head's output for each sample in the batch.
        """
        # VGGish expects [N, C=1, Time, Mel], so swap the last two axes.
        spec = x.permute(0, 1, 3, 2)
        fmap = self.base.features(spec)
        # Move channels last ([N, Time, Mel, C]) before flattening. This
        # matches the TensorFlow-compatible ordering the embedding layers use.
        fmap = fmap.permute(0, 2, 3, 1)
        flat = fmap.reshape(fmap.shape[0], -1)
        embedding = self.base.embeddings(flat)
        return self.classifier(embedding)