# Source code for sadaco.apis.models.compile_trt

from sklearn import model_selection
import torch
import time
from torch2trt import torch2trt

def compile(model: torch.nn.Module, input_shape: tuple = (1, 3, 224, 224), device_id=0,
            batch_size=1, checkpoint=None, output_names=None, *,
            fp16_mode: bool = True, use_onnx: bool = True):
    """Convert ``model`` to a TensorRT module using ``torch2trt``.

    Args:
        model: Module to convert. It is traced with a random dummy input that
            is moved to ``cuda(device_id)``. NOTE(review): this function does
            not move ``model`` itself, so it presumably must already reside on
            that device. Confirm callers guarantee this.
        input_shape: Shape of the random dummy input used for tracing.
        device_id: CUDA device the dummy input is moved to.
        batch_size: Forwarded to ``torch2trt`` as ``max_batch_size``.
        checkpoint: Currently unused. Kept only for interface compatibility.
        output_names: Output names passed to ``torch2trt``. Defaults to
            ``['out']``.
        fp16_mode: Forwarded to ``torch2trt``. The default ``True`` matches
            the previously hard-coded value.
        use_onnx: Forwarded to ``torch2trt``. The default ``True`` matches
            the previously hard-coded value.

    Returns:
        A ``(model, model_trt)`` tuple. ``model`` is the same object that was
        passed in. ``model_trt`` is the converted module, or ``None`` when
        CUDA is unavailable (in that case no conversion is attempted).
    """
    # Guard clause: without CUDA there is nothing to convert, so hand the
    # model back unchanged.
    if not torch.cuda.is_available():
        print('No CUDA Found Available! Aborting TRT Compile')
        return model, None

    if output_names is None:
        # Build a fresh list per call instead of sharing a mutable default.
        output_names = ['out']

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments (unlike time.time()).
    start = time.perf_counter()
    print('Converting...')
    model_trt = torch2trt(
        model,
        [torch.randn(input_shape).cuda(device_id)],
        fp16_mode=fp16_mode,
        use_onnx=use_onnx,
        output_names=output_names,
        max_batch_size=batch_size,
    )
    print("Done")
    print(f'TRT conversion took {time.perf_counter() - start:.2f} s')
    return model, model_trt