Source code for sadaco.apis.traintest.preprocessings

import torch
import torchaudio
from typing import DefaultDict, List

class Preprocessor:
    """Ordered chain of preprocessing modules applied to a batch dictionary.

    Every module is called with the batch dictionary and must return the
    dictionary to hand to the next module. Modules are also expected to
    provide a ``to(device)`` method, which :meth:`to` invokes on each of them.
    """

    def __init__(self, preproc_modules: List = None):
        """Create the chain.

        Args:
            preproc_modules: Modules to apply, in order. The list is stored by
                reference (not copied), so later ``add_module`` calls append
                to the caller's list as well. ``None`` starts an empty chain.
        """
        if preproc_modules is None:
            self.preproc_modules = []
        else:
            self.preproc_modules = preproc_modules

    def __call__(self, inputs: DefaultDict):
        """Run ``inputs`` through every module in insertion order.

        Args:
            inputs: Batch dictionary passed to the first module.

        Returns:
            Whatever the last module returned, or ``inputs`` itself when the
            chain is empty.
        """
        for pm in self.preproc_modules:
            inputs = pm(inputs)
        return inputs

    def add_module(self, module):
        """Append ``module`` to the end of the chain."""
        self.preproc_modules.append(module)

    def to(self, device):
        """Forward ``to(device)`` to every module in the chain.

        Args:
            device: Target passed unchanged to each module's ``to``.

        Returns:
            This preprocessor, so the call can be chained.
        """
        # A plain loop: the calls are made for their side effect only, so
        # there is no reason to collect their return values in a list.
        for pm in self.preproc_modules:
            pm.to(device)
        return self
class stft2meldb:
    """Convert an STFT spectrogram to a mel-scaled spectrogram in decibels.

    The tensor stored under ``inputs['input']`` is passed through
    ``torchaudio.transforms.MelScale`` and then through
    ``torchaudio.transforms.AmplitudeToDB`` (magnitude scaling, ``top_db=80``).
    """

    def __init__(self, n_stft, n_mels=128, sample_rate=16000):
        """Build the two torchaudio transforms.

        Args:
            n_stft: Number of STFT frequency bins, forwarded to ``MelScale``.
            n_mels: Number of mel bins, forwarded to ``MelScale``.
            sample_rate: Sample rate forwarded to ``MelScale``.
        """
        self.n_stft = n_stft
        self.n_mels = n_mels
        self.melscale = torchaudio.transforms.MelScale(
            sample_rate=sample_rate, n_mels=n_mels, n_stft=n_stft
        )
        self.p2d = torchaudio.transforms.AmplitudeToDB(stype='magnitude', top_db=80)

    def __call__(self, inputs: DefaultDict):
        """Replace ``inputs['input']`` with its mel/dB version.

        Args:
            inputs: Batch dictionary; only the ``'input'`` entry is read and
                overwritten. Presumably a spectrogram whose second-to-last
                axis has ``n_stft`` bins -- confirm against the caller.

        Returns:
            The same dictionary object, modified in place.
        """
        inputs['input'] = self.melscale(inputs['input'])
        inputs['input'] = self.p2d(inputs['input'])
        return inputs

    def to(self, device):
        """Move both transforms to ``device``.

        Args:
            device: Target passed to each transform's ``to``.

        Returns:
            This object, so the call can be chained.
        """
        self.melscale = self.melscale.to(device)
        # Relocate the dB transform as well, so both stages stay together
        # whether or not it currently holds any tensors.
        self.p2d = self.p2d.to(device)
        return self
class normalize:
    """Standardize ``inputs['input']`` as ``(x - mean) / std``."""

    def __init__(self, mean=0, std=1):
        """Store the normalization statistics.

        Args:
            mean: Value subtracted from the input (scalar or tensor).
            std: Value the centered input is divided by (scalar or tensor).
        """
        self.mean = mean
        self.std = std

    def __call__(self, inputs: DefaultDict):
        """Normalize the ``'input'`` entry of the batch dictionary.

        Args:
            inputs: Batch dictionary; only the ``'input'`` entry is read and
                overwritten.

        Returns:
            The same dictionary object, modified in place.
        """
        inputs['input'] = (inputs['input'] - self.mean) / self.std
        return inputs

    def to(self, device):
        """Move tensor-valued statistics to ``device``.

        Plain Python numbers are left untouched. Tensor ``mean``/``std`` are
        relocated so they stay on the same device as the inputs.

        Args:
            device: Target passed to ``Tensor.to``.

        Returns:
            This object, so the call can be chained.
        """
        if torch.is_tensor(self.mean):
            self.mean = self.mean.to(device)
        if torch.is_tensor(self.std):
            self.std = self.std.to(device)
        return self
if __name__ == "__main__":
    # Smoke test: run a random tensor through a mel/dB + normalize chain.
    dummy = torch.randn((1, 1, 524, 128))

    my_preproc = Preprocessor()
    # stft2meldb requires n_stft. MelScale treats the second-to-last axis as
    # frequency, so take the bin count from the dummy tensor itself.
    my_preproc.add_module(stft2meldb(n_stft=dummy.shape[-2]))
    my_preproc.add_module(normalize())

    print(my_preproc({'input': dummy, '1': None, '2': None}))