Source code for sadaco.dataman.icbhi.dummy

import os
import numpy as np
from torch.utils.data import Dataset
import torchvision
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
import random
from tqdm import tqdm
import torch
import torchaudio

torch.set_default_tensor_type(torch.FloatTensor)
[docs]class RespiDatasetSTFT(Dataset): def __init__(self, split, mixup=False, initialize=True, data_dir="dataset/spec_cut", multi_label=False, mean=None, std=None, fixed_length=None, sr=16000, num_mel=None, hop_length=5, window_size=70, **kwargs): super(RespiDatasetSTFT, self).__init__() self.return_vars = ['mag', 'label1', 'label2', 'phase'] self.split=split self.mixup=mixup self.hop_length = hop_length self.window_size = window_size self.data_dir=data_dir self.path=os.listdir(self.data_dir) # only used if data need to be in fixed length self.fixed_length=fixed_length self.multi_label = multi_label self.weights = [] if initialize: if mean is None or std is None: self.mean, self.std = self.initialize(self.path, self.multi_label) else: self.mean = mean self.std = std print(self.mean, self.std) else: self.mean = mean self.std = std self.sample_rate = sr dummy = torch.stft(torch.randn(1,self.sample_rate), n_fft = int(1e-3*self.window_size*self.sample_rate+1), hop_length=int(1e-3*self.hop_length*self.sample_rate), window = torch.hann_window(int(1e-3*self.window_size*self.sample_rate+1)) ) self.n_stft = dummy.shape[1] self.num_mel = num_mel # self.fm = torchaudio.transforms.FrequencyMasking(int(0.1*self.n_stft)) self.tm = torchaudio.transforms.TimeMasking(int(0.1*self.fixed_length)) self.transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomCrop((self.n_stft, self.fixed_length)) ]) self.norm_mean = -4.2677393 self.norm_std = 4.5689974 def _wav2fbank(self, filename, filename2=None): # mixup if filename2 == None: waveform, sr = torchaudio.load(filename) waveform = waveform - waveform.mean() # mixup else: waveform1, sr = torchaudio.load(filename) waveform2, _ = torchaudio.load(filename2) waveform1 = waveform1 - waveform1.mean() waveform2 = waveform2 - waveform2.mean() if waveform1.shape[1] != waveform2.shape[1]: if waveform1.shape[1] > waveform2.shape[1]: temp_wav = waveform2.repeat(1, waveform1.shape[-1]//waveform2.shape[-1] + 1) waveform2 = temp_wav[0, 0:waveform1.shape[-1]] else: randidx = np.random.randint(low=0, high=waveform2.shape[1]-waveform1.shape[1], size=(1,)) waveform2 = waveform2[0, randidx[0]:randidx[0]+waveform1.shape[1]] mix_lambda = np.random.beta(10, 10) mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2 waveform = mix_waveform - mix_waveform.mean() waveform = torch.nn.functional.pad(waveform, (0, 0), "constant") cart = torch.stft(waveform, n_fft = int(1e-3*self.window_size*self.sample_rate+1), hop_length=int(1e-3*self.hop_length*self.sample_rate), window = torch.hann_window(int(1e-3*self.window_size*self.sample_rate+1)), return_complex=True, pad_mode='reflect') phase = torch.atan2(cart.imag, cart.real) mag = cart.abs()**2 if torch.isnan(mag).any(): print(f'NaN mag!!! {filename}-{filename2}') if filename2 == None: return mag, phase, 1 else: return mag, phase, mix_lambda
[docs] def __getitem__(self, index): """ returns: image, audio, nframes where image is a FloatTensor of size (3, H, W) audio is a FloatTensor of size (N_freq, N_frames) for spectrogram, or (N_frames) for waveform nframes is an integer """ # do mix-up for this sample (controlled by the given mixup rate) # breakpoint() if self.mixup and random.random() < 0.5 and self.split == 'train': # print('MIXUP') datum = self.data[index] mix_sample_idx = random.randint(0, len(self.data)-1) mix_datum = self.data[mix_sample_idx] mag, phase, mix_lambda = self._wav2fbank(datum, mix_datum) # initialize the label label1 = torch.from_numpy(np.array(self.labels[index])) label2 = torch.from_numpy(np.array(self.labels[mix_sample_idx])) else: datum = self.data[index] mag, phase, mix_lambda = self._wav2fbank(datum) label = torch.from_numpy(np.array(self.labels[index])) label1 = label label2 = 0*label if mag.shape[-1] < self.fixed_length: # print('UNEXPECTED!!!') mag = mag.repeat(1, 1, self.fixed_length//mag.shape[-1] + 1) phase = phase.repeat(1, 1, self.fixed_length//phase.shape[-1] + 1) # mag = (mag - self.norm_mean) / (self.norm_std * 2) if self.split == 'train': magphase = self.transforms(torch.cat((mag.unsqueeze(0), phase.unsqueeze(0)), dim=0)) mag = magphase[0] phase = magphase[1] mag = self.tm(mag) # mag = self.fm(mag) else: mag = mag[:,:,:self.fixed_length] phase = phase[:,:,:self.fixed_length] if torch.isnan(mag).any(): print(f'NaN mag!!! {index}') return mag, label1, label2, mix_lambda, phase
[docs] def initialize(self, paths, multi_label): wavs = [torch.empty(1)]*len(paths) labels = [np.empty(1)]*len(paths) for i, s in tqdm(enumerate(paths),total=len(paths)): sp = self.data_dir+"/"+s ann = s.split('_')[-1].split('.')[0] wavs[i] = sp if multi_label: ann = self.to_multi_hot(ann) else: ann = self.to_one_hot(ann) labels[i] = ann self.data = wavs self.labels = labels return 0, 1
[docs] def to_multi_hot(self, ann): label = [0.]*len(ann) for i, an in enumerate(ann): if an == '1': label[i] = 1.0 return label
[docs] def to_one_hot(self, ann): label = [0]*(2**len(ann)) label[int(ann,2)] = 1.0 return label
[docs] def to_int(self, ann): label = int(ann, 2) return label
def __len__(self): return len(self.data)
[docs] def recover(self, mag, phase): mag = torch.sqrt(torch.relu(mag)) recombine_magnitude_phase = torch.cat( [(mag*torch.cos(phase)).unsqueeze(-1), (mag*torch.sin(phase)).unsqueeze(-1)], dim=-1) recon = torch.istft(recombine_magnitude_phase, n_fft = int(1e-3*self.window_size*self.sample_rate+1), hop_length=int(1e-3*self.hop_length*self.sample_rate), window = torch.hann_window(int(1e-3*self.window_size*self.sample_rate+1))) return recon
# train_dataset = RespiDatasetSTFT(split='train', data_dir='/train', initialize=True, # num_mel=128, multi_label=False, fixed_length=128, mixup=args.mixup)