1+ import argparse
2+ import onnx
3+ import onnxsim
4+ import torch
5+ import torch .nn as nn
6+ from super_gradients .training import models
7+ from super_gradients .common .object_names import Models
8+
9+
10+
11+ class YOLONAS (nn .Module ):
12+ def __init__ (self , model ):
13+ super ().__init__ ()
14+ self .model = model
15+ self .model .eval ()
16+
17+ def forward (self , input ):
18+
19+ output = self .model (input )
20+ return torch .cat (output , dim = - 1 )
21+
22+
23+ def parse_args ():
24+ parser = argparse .ArgumentParser ()
25+ parser .add_argument ('--model' , type = str , default = 'yolo_nas_m' ,
26+ choices = ['yolo_nas_s' ,'yolo_nas_m' , 'yolo_nas_l' ] ,
27+ help = 'model.pt' )
28+ parser .add_argument ('--save-model' , type = str , default = 'yolonas-m.onnx' ,
29+ help = 'model.onnx' )
30+ parser .add_argument ('--img-size' , nargs = '+' , type = int , default = [640 , 640 ],
31+ help = 'image (h, w)' )
32+ parser .add_argument ('--batch-size' , type = int , default = 1 , help = 'batch size' )
33+ parser .add_argument ('--half' , action = 'store_true' , help = 'FP16 export' )
34+ parser .add_argument ('--dynamic' , action = 'store_true' , help = 'dynamic axes' )
35+ parser .add_argument ('--simplify' , action = 'store_false' , help = 'simplify model' )
36+ parser .add_argument ('--opset' , type = int , default = 11 , help = 'opset version' )
37+ args = parser .parse_args ()
38+ return args
39+
40+
41+ def main (model ,
42+ save_model ,
43+ img_size ,
44+ batch_size ,
45+ opset = 11 ,
46+ half = False ,
47+ dynamic = False ,
48+ simplify = True ):
49+ model = models .get (model , pretrained_weights = "coco" )
50+ model .prep_model_for_conversion (input_size = [1 , 3 , 640 , 640 ])
51+
52+ model = YOLONAS (model )
53+ model .eval ()
54+ if dynamic :
55+ dynamic = {'images' : {0 : 'batch' , 2 : 'height' , 3 : 'width' },
56+ 'output0' : {0 : 'batch' , 1 : 'anchors' }}
57+
58+ img_size *= 2 if len (img_size ) == 1 else 1
59+ dummy_input = torch .zeros (batch_size , 3 , * img_size )
60+
61+ torch .onnx .export (model ,
62+ dummy_input ,
63+ save_model ,
64+ input_names = ['images' ],
65+ output_names = ['output0' ],
66+ opset_version = opset ,
67+ do_constant_folding = True ,
68+ dynamic_axes = dynamic or None )
69+ model_onnx = onnx .load (save_model )
70+ onnx .checker .check_model (model_onnx )
71+ if simplify :
72+ model_onnx , check = onnxsim .simplify (model_onnx )
73+ assert check , 'simplify failed'
74+ onnx .save (model_onnx , save_model )
75+
76+ if __name__ == '__main__' :
77+ args = parse_args ()
78+ main (** vars (args ))
0 commit comments