ONNX(Open Neural Network Exchange),跨框架的模型中间表达框架。是一种针对机器学习所设计的开放式文件格式,用于存储训练好的模型,使不同深度学习框架可以采用相同的格式存储模型并相互交互。目前支持PyTorch、Caffe2、MXNet、MNN、TensorRT;TensorFlow也有非官方的ONNX支持。
- PyTorch: 便于快速实验迭代的深度学习框架
- Caffe2: 便于算法和模型大规模部署在移动设备端。
#### Convert

```python
## torch.onnx模块包含将PyTorch模型导出为ONNX IR文件格式的功能。
import torch
import torch.onnx

input = torch.rand(1, 1, 32, 32)
model = LeNet5(1, 15, cfg.SOFTMAX, cfg.DROPOUT)  ## Model

## dist model
"""
model = nn.DataParallel(model)
model.load_state_dict(torch.load(cfg.TEST_CKPT))
output = model(input)
"""

## single model
model.load_state_dict({k.replace('module.', ''): v
                       for k, v in torch.load(cfg.TEST_CKPT).items()})
output = model(input)

## set input_names and output_names
"""
input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(25)]
output_names = ["output1"]
"""

## save '.onnx' or '.onnx.pb' or '.proto' format onnx file
torch.onnx.export(model, input, "LeNet5.onnx", verbose=True)
#                 input_names=input_names, output_names=output_names)
```

#### Load and Test

```python
## PyTorch-1.1.0 + onnx-1.5.0(其他版本需自行验证)
import onnx

onnx_model = onnx.load('xxx.onnx')
# Check that the ONNX IR is well formed
onnx.checker.check_model(onnx_model)
# Print a human readable representation of the graph
onnx.helper.printable_graph(onnx_model.graph)
```

#### Caffe2-to-ONNX

```python
## ONNX to Caffe2 Error: "ValueError: Don't know how to translate op Unsqueeze"
## 修改 xx/onnx_caffe2/backend.py 文件
_renamed_operators = {
    'Caffe2ConvTranspose': 'ConvTranspose',
    'GlobalMaxPool': 'MaxPool',
    'GlobalAveragePool': 'AveragePool',
    'Pad': 'PadImage',
    'Neg': 'Negative',
    'BatchNormalization': 'SpatialBN',
    'InstanceNormalization': 'InstanceNorm',
    'MatMul': 'BatchMatMul',
    'Upsample': 'ResizeNearest',
    'Equal': 'EQ',
    'Unsqueeze': 'ExpandDims',  # add this line
}
_global_renamed_attrs = {'kernel_shape': 'kernels'}
_per_op_renamed_attrs = {
    'Squeeze': {'axes': 'dims'},
    'Transpose': {'perm': 'axes'},
    'Upsample': {'mode': ''},
    'Unsqueeze': {'axes': 'dims'},  # add this line
}
```

#### ONNX-to-Caffe2

```shell
# Install caffe2 and onnx-caffe2 packages
pip install caffe2
pip install onnx-caffe2

# 1. cmd Convert
convert-onnx-to-caffe2 $xxx.onnx --output pred_net.pb --init-net-output init_net.pb
```

```python
# 2. API Convert
import onnx_caffe2.backend as backend

init_net, pred_net = backend.Caffe2Backend.onnx_graph_to_caffe2_net(
    onnx_model.graph, device='CPU')
with open('init_net.pb', 'wb') as fd1:
    fd1.write(init_net.SerializeToString())
with open('pred_net.pb', 'wb') as fd2:
    fd2.write(pred_net.SerializeToString())

# Using python code test caffe2 model
rep = backend.prepare(onnx_model, device='CPU')
result = rep.run(np.random.randn(1, 1, 32, 32).astype(np.float32))
print(result[0])
```

#### ONNX-to-MNN

```shell
# Install MNN tools package
pip install -U MNN

# mnnconvert cmd (https://convertmodel.com/?tdsourcetag=s_pcqq_aiomsg)
mnnconvert -f ONNX --modelFile $script_path/../LeNet5.onnx --MNNModel LeNet5.mnn --bizCode MNN

# Using python code test mnn model
git clone https://github.com/alibaba/MNN.git
vim ./MNN/pymnn/examples/MNNEngineDemo/mobilenet_demo.py
```

```python
# inference
interpreter = MNN.Interpreter('/home/wyf/codes/traffic-sign-classification/LeNet5.mnn')
session = interpreter.createSession()
input_tensor = interpreter.getSessionInput(session)
# construct tensor from np.ndarray
tmp_img = MNN.Tensor((1, 1, 32, 32), MNN.Halide_Type_Float, img,
                     MNN.Tensor_DimensionType_Caffe)
input_tensor.copyFrom(tmp_img)
interpreter.runSession(session)
output_tensor = interpreter.getSessionOutput(session)
pred = np.argmax(output_tensor.getData())
print(pred)
```
