from time import sleep
from typing import Callable, Optional, Union, DefaultDict, Tuple, List
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sadaco.utils.stats import ICBHI_Metrics, print_stats
from torch.cuda.amp import autocast,GradScaler
def move_device(data : DefaultDict, device : torch.device):
    """Return a new dict with the entries of ``data`` moved to ``device``.

    Only values exposing a ``to`` attribute are kept, each replaced by the
    result of ``value.to(device)``. Every other entry is left out of the
    returned dict. The input mapping itself is not modified.

    Args:
        data: Mapping from keys to values (e.g. a collated batch).
        device: Target device passed to each value's ``to`` method.

    Returns:
        A plain ``dict`` containing the moved entries, in the iteration
        order of ``data``.
    """
    moved = {}
    for key, value in data.items():
        # Entries that cannot be moved are skipped rather than copied over.
        if not hasattr(value, 'to'):
            continue
        moved[key] = value.to(device)
    return moved
def test_basic_epoch(
    model: nn.Module,
    device: torch.device,
    test_loader: DataLoader,
    metrics: ICBHI_Metrics,
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epoch: int,
    verbose: bool = True,
    preprocessing : Callable = None,
)-> Optional[Union[DefaultDict, np.ndarray]]:
    """Run one evaluation pass over ``test_loader`` and return the metric stats.

    The model is switched to eval mode and moved to ``device``; the loop runs
    under ``torch.no_grad()`` and ``autocast()``. Each batch must be a mapping
    with at least ``'input'`` and ``'label'`` entries. The criterion is called
    with the (device-moved) batch entries plus ``'output'`` as keyword
    arguments, and its result is averaged with ``.mean()``.

    Args:
        model: Network to evaluate; called as ``model(inputs['input'])``.
        device: Device the model, the batch and the preprocessing are moved to.
        test_loader: Loader yielding batch mappings.
        metrics: Accumulator exposing ``update_lists``, ``get_stats`` and
            ``reset_metrics``. It receives the raw model output as ``logits``
            and ``argmax(label, dim=-1)`` as ``y_true`` (labels are presumably
            one-hot / per-class scores -- TODO confirm).
        criterion: Loss callable accepting the batch entries as keyword
            arguments. If it has a ``reduction`` attribute equal to ``'mean'``
            or ``'none'`` (or no such attribute at all), the batch loss is
            weighted by the batch size before accumulation.
        epoch: Epoch number, used only for the progress-bar label.
        verbose: When true, print the formatted stats at the end.
        preprocessing: Optional callable applied to the batch mapping; it must
            return a mapping with an ``'input'`` entry. If it has a ``to``
            method it is moved to ``device`` once before the loop.

    Returns:
        The object returned by ``metrics.get_stats()`` with an extra
        ``'Test Loss'`` entry: the accumulated loss divided by
        ``len(test_loader.dataset)``.
    """
    model.eval().to(device)

    # Loop-invariant: move the preprocessing module once, not on every batch.
    if preprocessing is not None and hasattr(preprocessing, 'to'):
        preprocessing.to(device)

    # A criterion without a ``reduction`` attribute is treated like 'mean':
    # after ``.mean()`` its value is assumed to be a per-sample average.
    weight_by_batch = getattr(criterion, 'reduction', 'mean') in ('mean', 'none')

    test_loss = 0
    samples_seen = 0
    with tqdm(
        test_loader,
        unit="batch",
        desc=f"Epoch [{epoch}]",
        total=len(test_loader),
        leave=False
    ) as pbar, torch.no_grad(), autocast():
        for batch_info in pbar:
            batch_info = move_device(batch_info, device)
            if torch.isnan(batch_info['input']).any():
                print('NaN mag!!! val before preproc')

            if preprocessing is not None:
                inputs = preprocessing(batch_info)
            else:
                inputs = batch_info
            if torch.isnan(inputs['input']).any():
                print('NaN mag!!! val after preproc')

            output = model(inputs['input'])
            batch_info.update({'output': output})
            loss = criterion(**batch_info).mean()
            if torch.isnan(loss):
                # Report instead of dropping into a debugger, which would
                # hang non-interactive evaluation runs.
                print('NaN loss!!! val')

            batch_loss = loss.item()
            batch_size = output.shape[0]
            if weight_by_batch:
                test_loss += batch_loss * batch_size
            else:
                test_loss += batch_loss
            samples_seen += batch_size

            # Show the running per-sample average rather than the current
            # batch loss divided by the batch index.
            pbar.set_postfix(loss=test_loss / max(samples_seen, 1))
            metrics.update_lists(logits=output, y_true=torch.argmax(batch_info['label'], dim=-1))

    test_loss /= len(test_loader.dataset)
    stats = metrics.get_stats()
    stats.update({'Test Loss' : test_loss})
    metrics.reset_metrics()
    if verbose:
        print(print_stats(stats))
    return stats
def test_merge_epoch(
    model: nn.Module,
    device: torch.device,
    test_loader: DataLoader,
    metrics: ICBHI_Metrics,
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epoch: int,
    verbose: bool = True,
    preprocessing : Callable = None,
    target_classes: Union[List, Tuple] = None,
)-> Optional[Union[DefaultDict, np.ndarray]]:
    """Run one evaluation pass restricted to a subset of output classes.

    The model is switched to eval mode and moved to ``device``; the loop runs
    under ``torch.no_grad()`` and ``autocast()``. Each batch must be a mapping
    with at least ``'input'`` and ``'label'`` entries. Both the model output
    and the labels are indexed along their second axis with
    ``target_classes`` before being passed to the criterion and the metrics.

    Args:
        model: Network to evaluate; called as ``model(inputs['input'])``.
        device: Device the model, the batch and the preprocessing are moved to.
        test_loader: Loader yielding batch mappings.
        metrics: Accumulator exposing ``update_lists``, ``get_stats`` and
            ``reset_metrics``. It receives the selected output columns as
            ``logits`` and the selected label columns as ``y_true``.
        criterion: Loss callable invoked with the (preprocessed) batch entries
            as keyword arguments, where ``'output'`` and ``'label'`` are
            replaced by their class-selected versions. If it has a
            ``reduction`` attribute equal to ``'mean'`` or ``'none'`` (or no
            such attribute at all), the batch loss is weighted by the batch
            size before accumulation.
        epoch: Epoch number, used only for the progress-bar label.
        verbose: When true, print the formatted stats at the end.
        preprocessing: Optional callable applied to the batch mapping; it must
            return a mapping with an ``'input'`` entry. If it has a ``to``
            method it is moved to ``device`` once before the loop.
        target_classes: Column indices to keep from output and labels.
            ``None`` keeps every column.

    Returns:
        The object returned by ``metrics.get_stats()`` with an extra
        ``'Test Loss'`` entry: the accumulated loss divided by
        ``len(test_loader.dataset)``.
    """
    model.eval().to(device)

    # Loop-invariant: move the preprocessing module once, not on every batch.
    if preprocessing is not None and hasattr(preprocessing, 'to'):
        preprocessing.to(device)

    # Indexing a tensor with ``None`` inserts a new axis instead of selecting
    # columns, so map the default to "all classes".
    class_idx = slice(None) if target_classes is None else target_classes

    # A criterion without a ``reduction`` attribute is treated like 'mean':
    # after ``.mean()`` its value is assumed to be a per-sample average.
    weight_by_batch = getattr(criterion, 'reduction', 'mean') in ('mean', 'none')

    test_loss = 0
    samples_seen = 0
    with tqdm(
        test_loader,
        unit="batch",
        desc=f"Epoch [{epoch}]",
        total=len(test_loader),
        leave=False
    ) as pbar, torch.no_grad(), autocast():
        for batch_info in pbar:
            batch_info = move_device(batch_info, device)
            if torch.isnan(batch_info['input']).any():
                print('NaN mag!!! val before preproc')

            if preprocessing is not None:
                inputs = preprocessing(batch_info)
            else:
                inputs = batch_info
            if torch.isnan(inputs['input']).any():
                print('NaN mag!!! val after preproc')

            output = model(inputs['input'])
            selected_output = output[:, class_idx]
            selected_label = batch_info['label'][:, class_idx]
            loss = criterion(**{**inputs, 'output': selected_output, 'label': selected_label}).mean()
            if torch.isnan(loss):
                # Report instead of dropping into a debugger, which would
                # hang non-interactive evaluation runs.
                print('NaN loss!!! val')

            batch_loss = loss.item()
            batch_size = output.shape[0]
            if weight_by_batch:
                test_loss += batch_loss * batch_size
            else:
                test_loss += batch_loss
            samples_seen += batch_size

            # Show the running per-sample average rather than the current
            # batch loss divided by the batch index.
            pbar.set_postfix(loss=test_loss / max(samples_seen, 1))
            metrics.update_lists(logits=selected_output, y_true=selected_label)

    test_loss /= len(test_loader.dataset)
    stats = metrics.get_stats()
    stats.update({'Test Loss' : test_loss})
    metrics.reset_metrics()
    if verbose:
        print(print_stats(stats))
    return stats