输入参数：
Net：网络结构（模型类，函数内通过 Net() 实例化）
pth：PyTorch 模型权重文件（state_dict）的路径
onnx：ONNX 模型文件的路径
def Test(Net, pth, onnx, input_name='input', input_shape=(1, 1, 40, 40)):
    """Run the same random input through an ONNX model and a PyTorch model.

    Both outputs are printed so they can be compared by eye. They are also
    returned so callers can compare them programmatically (e.g. with
    ``numpy.testing.assert_allclose``).

    Args:
        Net: Network class; instantiated with no arguments (``Net()``).
        pth: Path to the PyTorch checkpoint (a state dict). Keys may carry the
            ``module.`` prefix added when saving from ``nn.DataParallel``.
        onnx: Path to the exported ONNX model file.
        input_name: Name of the ONNX graph input that receives the dummy tensor.
        input_shape: Shape of the random dummy input tensor.

    Returns:
        Tuple ``(onnx_out, torch_out)``: the list returned by
        ``InferenceSession.run`` and the PyTorch model's output.
    """
    import onnxruntime
    import torch

    def to_numpy(tensor):
        # numpy() refuses tensors that track gradients, so detach those first.
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    dummy_input = torch.randn(*input_shape, device='cpu')

    # Give the file path straight to ONNX Runtime. onnxruntime.datasets.get_example
    # resolves names against onnxruntime's bundled example directory, so it
    # breaks for relative paths to user models. An explicit provider list is
    # required by GPU builds of onnxruntime >= 1.9; the input lives on CPU anyway.
    sess = onnxruntime.InferenceSession(onnx, providers=['CPUExecutionProvider'])
    onnx_out = sess.run(None, {input_name: to_numpy(dummy_input)})
    print(onnx_out)

    model = Net()
    model_dict = model.state_dict()
    # Map to CPU: the model and dummy input are on CPU, and mapping to CUDA
    # fails on machines without a GPU. torch.load unpickles data, so only load
    # trusted checkpoints.
    checkpoint = torch.load(pth, map_location=torch.device('cpu'))

    # Strip the 'module.' prefix (nn.DataParallel) only when present, and keep
    # only entries that exist in this model with a matching shape.
    prefix = 'module.'
    pretrained_dict = {}
    for key, value in checkpoint.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_dict and tuple(model_dict[name].shape) == tuple(value.shape):
            pretrained_dict[name] = value
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()

    with torch.no_grad():
        torch_out = model(dummy_input)
    print(torch_out)
    return onnx_out, torch_out