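函数 Test 用同一个随机输入分别运行 ONNX 模型和 PyTorch 模型,并打印两者的输出以便对比。
输入参数:
- Net:网络结构(模型类,函数内部通过 Net() 实例化)
- pth:PyTorch 模型权重文件(state_dict)的路径
- onnx:ONNX 模型文件的路径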
def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached from autograd and moved to CPU."""
    # detach() is a no-op for tensors that do not require grad, so no branch is needed.
    return tensor.detach().cpu().numpy()


def _load_matching_weights(model, state_dict, prefix='module.'):
    """Copy into ``model`` every checkpoint entry whose name and shape match.

    Args:
        model: ``torch.nn.Module`` to receive the weights.
        state_dict: Mapping of parameter names to tensors, as returned by ``torch.load``.
        prefix: Leading key prefix to strip when present (``nn.DataParallel`` adds
            ``module.``). Keys without the prefix are used unchanged.

    Returns:
        List of checkpoint keys that were skipped, either because the model has no
        such parameter or because the shapes differ.
    """
    import numpy as np

    model_dict = model.state_dict()
    matched = {}
    skipped = []
    for key, value in state_dict.items():
        # Strip the prefix only when it is actually there; blindly slicing k[7:]
        # mangles keys saved without DataParallel and then raises KeyError.
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_dict and np.shape(model_dict[name]) == np.shape(value):
            matched[name] = value
        else:
            skipped.append(key)
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return skipped


def Test(Net, pth, onnx, input_shape=(1, 1, 40, 40)):
    """Run an ONNX model and its PyTorch source on one random input and print both outputs.

    Args:
        Net: Model class; instantiated here with ``Net()``.
        pth: Path to the PyTorch checkpoint (a state_dict, optionally with
            ``module.``-prefixed keys from ``nn.DataParallel``).
        onnx: Path to the exported ONNX model file.
        input_shape: Shape of the random dummy input fed to both models.
            Defaults to ``(1, 1, 40, 40)``, the previously hard-coded shape.

    Returns:
        Tuple ``(torch_out, onnx_out)`` so callers can compare the two numerically.
    """
    import onnxruntime
    import torch

    dummy_input = torch.randn(*input_shape, device='cpu')

    # Pass the path straight to InferenceSession: onnxruntime.datasets.get_example
    # resolves names against onnxruntime's bundled sample directory, so a relative
    # path to the caller's own model would not be found.
    # An explicit provider list is required by ORT >= 1.9 on builds with several EPs.
    sess = onnxruntime.InferenceSession(onnx, providers=['CPUExecutionProvider'])
    # Read the input name from the graph instead of assuming the export used 'input'.
    input_name = sess.get_inputs()[0].name
    onnx_out = sess.run(None, {input_name: _to_numpy(dummy_input)})
    print(onnx_out)

    model = Net()
    # Load onto CPU: both the model and the dummy input live on CPU, and
    # map_location='cuda' fails outright on machines without a GPU.
    state_dict = torch.load(pth, map_location=torch.device('cpu'))
    skipped = _load_matching_weights(model, state_dict)
    if skipped:
        # Previously these were dropped silently; surface them so a bad
        # checkpoint/architecture pairing is noticed.
        print(f'Skipped {len(skipped)} checkpoint entries with no matching '
              f'name/shape in the model: {skipped}')
    model.eval()
    with torch.no_grad():
        torch_out = model(dummy_input)
    print(torch_out)
    return torch_out, onnx_out
