师兄说要先利用onnx接口转为onnx,然后用onnx转其他模型。(参考文档)
其实也就包括两个方面:
- 第一、模型的转换;
- 模型转换有对应的工具,比如说pytorch就有torch.onnx,可以实现将模型转为onnx;
- 第二、相关代码的重写(比如加载数据等);
- 考虑到嵌入式本身所支持的语言,需要重写代码;
pytorch转onnx
参考: pytorch2onnx
- 考虑到嵌入式本身所支持的语言,需要重写代码;
产生onnx protobuf文件
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
out = torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
模型进行一次前向推理,然后将跟踪结果保存在.onnx文件中,该文件为二进制protobuf文件,其中包含导出模型的网络结构和参数。当verbose为True时将会输出方便阅读的相关信息。
import torch
from PointsModel.PointsModel import SetNet1
dummy_input = torch.randn((2, 2, 14, 17))
model = SetNet1()
torch.onnx.export(model, dummy_input, 'setnet.onnx', verbose=True)
验证protobuf文件
验证过程需要先安装ONNX包:
conda install -c conda-forge onnx
注:其中的-c conda-forge表示的是channel。
import onnx
model = onnx.load('setnet.onnx')
# 检验IR?
onnx.checker.check_model(model)
# 打印模型的图表征
print(onnx.helper.printable_graph(model.graph))
注:
- onnx.checker.check_model(…)貌似没有任何反应,没有返回值;
- onnx.helper.printable_graph(…)倒会返回模型结构和参数,与torch.onnx.export(…)处输出一致;
caffe2测试模型
值得注意的是,pytorch自带caffe2。 ```python import onnx import caffe2.python.onnx.backend as backend import numpy as np
model = onnx.load(‘setnet.onnx’) rep = backend.prepare(model) outputs = rep.run(np.random.randn(2, 2, 14, 17).astype(np.float32)) print(outputs) ```
onnx的特点
trace-based:
- 需要运行一次模型进行跟踪,然后到处运行过程中实际产生的运算。所以:
- 动态的模型不能被正确导出;(动态:根据不同的输入数据产生不同的行为)
- 跟踪只对固定输入size有效;
- 对于有控制流的模型,跟踪过程中同样可能存在偏差:
- 对于循环和条件分支,会被展开,从而导出一个和运行过程中完全一模一样的静态的图;
- 对于具有动态控制的模型,需要采用script-based导出器;
script-based:
- 导出模型是一个ScriptModule:
- ScriptModule是TorchScript核心数据结构,TorchScript是Python语言,从PyTorch代码创建序列化和优化的模型的子集。
- ScriptModule可以实现动态模型的正确导出;
- 值得注意的是,在pytorch 1.2.0及以后才包括此项功能;