输入参数为:
    Net:网络结构(模型类,函数内以 Net() 无参实例化)
    pth:PyTorch 模型权重文件路径(保存的 state_dict)
    onnx:ONNX 模型文件路径

    def Test(Net, pth, onnx, input_shape=(1, 1, 40, 40), prefix="module."):
        """Run one random input through an ONNX model and a PyTorch model.

        Both outputs are printed and also returned, so they can be compared.

        Args:
            Net: model class; instantiated here as ``Net()`` with no arguments.
            pth: path to a PyTorch checkpoint containing a state_dict.
            onnx: path to the exported ONNX model file.
            input_shape: shape of the random float32 input tensor. The default
                equals the previously hard-coded (1, 1, 40, 40).
            prefix: key prefix to strip from checkpoint keys, e.g. the
                ``module.`` added when a model is wrapped in ``nn.DataParallel``.
                Keys that do not start with it are used unchanged.

        Returns:
            Tuple ``(onnx_out, torch_out)``: the list returned by
            ``InferenceSession.run`` and the PyTorch model output.
        """
        import onnxruntime
        import torch

        def to_numpy(tensor):
            # .numpy() refuses tensors that require grad, so detach those first.
            return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

        dummy_input = torch.randn(*input_shape, device='cpu')

        # Open the user's file directly: onnxruntime.datasets.get_example only
        # resolves names inside onnxruntime's bundled datasets directory, so a
        # relative path to a user model would not be found.
        # Explicit providers are mandatory since onnxruntime 1.9 on builds that
        # enable more than one execution provider.
        sess = onnxruntime.InferenceSession(onnx, providers=['CPUExecutionProvider'])
        # Use the model's actual input name instead of assuming 'input'.
        input_name = sess.get_inputs()[0].name
        onnx_out = sess.run(None, {input_name: to_numpy(dummy_input)})
        print(onnx_out)

        model = Net()
        model_dict = model.state_dict()
        # Load to CPU: the model and the input both live on the CPU, and mapping
        # to 'cuda' would fail on machines without a GPU.
        checkpoint = torch.load(pth, map_location=torch.device('cpu'))

        pretrained_dict = {}
        skipped = []
        for key, value in checkpoint.items():
            # Strip the prefix only when present; blindly slicing off 7 chars
            # mangles keys saved without a DataParallel wrapper.
            name = key[len(prefix):] if prefix and key.startswith(prefix) else key
            # Keep only entries that exist in the model with a matching shape.
            if name in model_dict and model_dict[name].shape == value.shape:
                pretrained_dict[name] = value
            else:
                skipped.append(key)
        if skipped:
            # Report instead of silently ignoring unmatched parameters.
            print('skipped %d checkpoint entries: %s' % (len(skipped), skipped))

        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model.eval()
        with torch.no_grad():
            torch_out = model(dummy_input)
        print(torch_out)
        return onnx_out, torch_out