I trained a binary image classifier with the timm library, using ConvNeXt-Base as the pretrained backbone. When I converted the model to ONNX to reduce memory usage and response time, I observed a significant increase in latency instead. Specifically, for a batch of 100 images, inference with the PyTorch .pth model took 7.6024 seconds, while the ONNX model took 74.3477 seconds under the same conditions, roughly 10x slower. Is this level of performance degradation typical? It was not what I anticipated.
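For reference, the timing comparison looked essentially like this (a simplified sketch, not my exact benchmark: a random batch stands in for the real preprocessed images, the warm-up calls are part of the sketch, and I am assuming the CPU execution provider with default session options, matching the CPU-only loading below):

import time
import timm
import torch
import onnxruntime as ort

# Stand-in for the batch of 100 preprocessed 512x512 images
batch = torch.randn(100, 3, 512, 512)

# PyTorch (.pth) timing
model = timm.create_model('convnext_base', pretrained=False, num_classes=2)
model.load_state_dict(torch.load('models/convnext_base_classifier-p12.pth',
                                 map_location='cpu'))
model.eval()
with torch.no_grad():
    model(batch[:1])  # warm-up
    start = time.time()
    model(batch)
    print(f"PyTorch: {time.time() - start:.4f} s")

# ONNX Runtime timing (CPU provider, default session options)
session = ort.InferenceSession('models/convnext_base_classifier-p12.onnx',
                               providers=['CPUExecutionProvider'])
session.run(None, {'input': batch[:1].numpy()})  # warm-up
start = time.time()
session.run(None, {'input': batch.numpy()})
print(f"ONNX Runtime: {time.time() - start:.4f} s")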
This is the code I used to export the ONNX model:
import time
import torch
import torch.onnx
import logging
import timm
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_model(model_path, model_name='convnext_base', num_classes=2, pretrained=False):
    if not os.path.exists(model_path):
        logger.error(f"Model path {model_path} does not exist.")
        return None
    try:
        start_time = time.time()
        model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()
        logger.info(f"Loading model and weights: {time.time() - start_time:.2f} seconds")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return None

def export_to_onnx(model, path, input_shape=(1, 3, 512, 512)):
    try:
        start_time = time.time()
        dummy_input = torch.randn(input_shape)
        torch.onnx.export(
            model, dummy_input, path,
            export_params=True,
            opset_version=18,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        logger.info(f"Exporting to ONNX: {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Failed to export to ONNX: {e}")

if __name__ == "__main__":
    model_path = 'models/convnext_base_classifier-p12.pth'
    onnx_model_path = 'models/convnext_base_classifier-p12.onnx'
    num_classes = 2  # Set this to the number of output classes in your model

    # Load model
    model = load_model(model_path, model_name='convnext_base', num_classes=num_classes)
    if model:
        # Export to ONNX
        export_to_onnx(model, onnx_model_path)

        # Load and check the ONNX model
        import onnx
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        logger.info("ONNX model is well-formed and valid.")
In short: I converted a PyTorch model (.pth) to ONNX expecting better latency, but I got dramatically worse latency instead.