Skip to content

Commit 963e39c

Browse files
committedNov 19, 2021
Added notebook to export yolox model in torchscript format
1 parent f380ee4 commit 963e39c

File tree

2 files changed

+238
-0
lines changed

2 files changed

+238
-0
lines changed
 

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
assets
2+
models

‎yolox-torchscript.ipynb

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "2aad954d",
6+
"metadata": {},
7+
"source": [
8+
"# yolox models in torchscript format\n",
9+
"\n",
10+
"This notebook shows you how to export [yolox](https://github.com/Megvii-BaseDetection/YOLOX) models in [torchscript](https://pytorch.org/docs/stable/jit.html) format, that later on can be used as an input to [AWS SageMaker Neo](https://docs.aws.amazon.com/sagemaker/latest/dg/neo.html) compilation job or as an [AWS Panorama](https://docs.aws.amazon.com/panorama/latest/dev/index.html) model. For simplicity, we'll use the yolox-s(mall) version, but it should work also for the other model versions. \n",
11+
"\n",
12+
"This code is roughly based on the [`tools/export_torchscript.py`](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/tools/export_torchscript.py) script. In order to be compiled by Neo or used by Panorama, the pytorch Sigmoid Linear Unit (SiLU) activation function has to be replaced by a custom implementation, as SiLU is not typically implemented in embedeed runtimes, and the compilation engine Apache TVM does not support it. This idea was taken from the [`tools/export_onnx.py`](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/tools/export_onnx.py) script."
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"id": "356578e0",
18+
"metadata": {},
19+
"source": [
20+
"## Get the YOLOX repository and install dependencies"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "7f0570d7",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"!git clone https://github.com/Megvii-BaseDetection/YOLOX"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"id": "c7876834",
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"%pip install -r ./YOLOX/requirements.txt\n",
41+
"%pip install -v -e ./YOLOX/\n",
42+
"%pip install cython\n",
43+
"%pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'"
44+
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"id": "86a72e70",
49+
"metadata": {},
50+
"source": [
51+
"yolox requirements.txt does not specify the exact pytorch version, however for SageMaker Neo suppports only pytorch 1.6, 1.7, and 1.8."
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"id": "00138450",
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"%pip install torch==1.8.0 torchvision==0.9.0"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"id": "7b520c2c",
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"import torch\n",
72+
"import torchvision\n",
73+
"print('pytorch version:', torch.__version__)\n",
74+
"print('torchvision version:', torchvision.__version__)"
75+
]
76+
},
77+
{
78+
"cell_type": "markdown",
79+
"id": "650a35f4",
80+
"metadata": {},
81+
"source": [
82+
"## Download the pretrained yolox-s weights\n",
83+
"\n",
84+
"YOLOX authors released model weight artifacts pretrained on the [COCO dataset](https://cocodataset.org/). The format of these pretrained models is a pytorch checkpoint file that contains the model's [state dictionary](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).\n",
85+
"\n",
86+
"Refer to [YOLOX GitHub repository](https://github.com/Megvii-BaseDetection/YOLOX) for other pretrained model urls."
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"id": "fd1c7ed1",
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"!mkdir -p models\n",
97+
"\n",
98+
"ckpt_url = 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth'\n",
99+
"ckpt_filename = './models/yolox_s.pth'\n",
100+
"\n",
101+
"!curl -L $ckpt_url -o $ckpt_filename"
102+
]
103+
},
104+
{
105+
"cell_type": "markdown",
106+
"id": "5e203cd5",
107+
"metadata": {},
108+
"source": [
109+
"## Initialize the network\n",
110+
"\n",
111+
"The weights contained in the state dictionary can be loaded into a neural network. We'll need a fully initialized network in order to export it in torchscript format."
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"id": "5a74064a",
118+
"metadata": {},
119+
"outputs": [],
120+
"source": [
121+
"from torch import nn\n",
122+
"from yolox.exp import get_exp\n",
123+
"\n",
124+
"exp = get_exp(None, 'yolox-s')\n",
125+
"model = exp.get_model()\n",
126+
"ckpt = torch.load(ckpt_filename, map_location='cpu')\n",
127+
"model.load_state_dict(ckpt['model'])"
128+
]
129+
},
130+
{
131+
"cell_type": "markdown",
132+
"id": "d895e476",
133+
"metadata": {},
134+
"source": [
135+
"## Patch the model before exporting\n",
136+
"\n",
137+
"As mentioned earlier, the pytorch SiLU activation function has to be replaced by a custom implementation. Also we'll disable decoding in the model head, as it is used only during training. We'll also set the model to evaluation mode (disables dropout and other training-only features)."
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"id": "8a6fe19d",
144+
"metadata": {},
145+
"outputs": [],
146+
"source": [
147+
"from yolox.utils import replace_module\n",
148+
"from yolox.models.network_blocks import SiLU\n",
149+
"\n",
150+
"model = model.eval()\n",
151+
"model = replace_module(model, nn.SiLU, SiLU)\n",
152+
"model.head.decode_in_inference = False"
153+
]
154+
},
155+
{
156+
"cell_type": "markdown",
157+
"id": "fab48496",
158+
"metadata": {},
159+
"source": [
160+
"## Export model in torchscript format\n",
161+
"\n",
162+
"Unlike a dynamic pytorch model, models saved in torchscript are static. This means that the input size of your model can not be any more dynamic, and you have to specify it at export time (now). The yolox experiment instance give you hint about the input size, in the case of yolox-s model, it is 640x640. We'll create a dummy input image of this size, as this is required when exporting the model in torchscript format. The other dimensions of the input is the batch size (1) and the channels (3 for red, green, and blue)."
163+
]
164+
},
165+
{
166+
"cell_type": "code",
167+
"execution_count": null,
168+
"id": "789d215b",
169+
"metadata": {},
170+
"outputs": [],
171+
"source": [
172+
"traced_model_filename = './models/yolox_s_torchscript.pth'\n",
173+
"\n",
174+
"input_size = [1, 3, *exp.test_size]\n",
175+
"print('Exported model input size:', input_size)\n",
176+
"dummy_input = torch.randn(*input_size)\n",
177+
"traced_model = torch.jit.trace(model, dummy_input)\n",
178+
"traced_model.save(traced_model_filename)\n",
179+
"print('Exported model was saved to:', traced_model_filename)"
180+
]
181+
},
182+
{
183+
"cell_type": "markdown",
184+
"id": "7066ad1e",
185+
"metadata": {},
186+
"source": [
187+
"## Archive the model\n",
188+
"\n",
189+
"SageMaker Neo and Panorama both expect the model archived in a tar.gz file. "
190+
]
191+
},
192+
{
193+
"cell_type": "code",
194+
"execution_count": null,
195+
"id": "9bae83ef",
196+
"metadata": {},
197+
"outputs": [],
198+
"source": [
199+
"import tarfile\n",
200+
"\n",
201+
"model_archive_filename = './models/yolox_s_torchscript.tar.gz'\n",
202+
"with tarfile.open(model_archive_filename, \"w:gz\") as f:\n",
203+
" f.add(traced_model_filename)\n",
204+
"print('Exported model was archived as:', model_archive_filename)"
205+
]
206+
},
207+
{
208+
"cell_type": "markdown",
209+
"id": "15eb91c0",
210+
"metadata": {},
211+
"source": [
212+
"Now you can specify this archive as a Panorama model asset, or upload it to S3 and start a SageMaker Neo compilation job with it."
213+
]
214+
}
215+
],
216+
"metadata": {
217+
"kernelspec": {
218+
"display_name": "conda_pytorch_p36",
219+
"language": "python",
220+
"name": "conda_pytorch_p36"
221+
},
222+
"language_info": {
223+
"codemirror_mode": {
224+
"name": "ipython",
225+
"version": 3
226+
},
227+
"file_extension": ".py",
228+
"mimetype": "text/x-python",
229+
"name": "python",
230+
"nbconvert_exporter": "python",
231+
"pygments_lexer": "ipython3",
232+
"version": "3.6.13"
233+
}
234+
},
235+
"nbformat": 4,
236+
"nbformat_minor": 5
237+
}

0 commit comments

Comments
 (0)
Please sign in to comment.