通过使用 Flask 的 REST API 在 Python 中部署 PyTorch

原文:https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html

作者Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开用于模型推理的 REST API。 特别是,我们将部署预训练的 DenseNet 121 模型来检测图像。

小费

此处使用的所有代码均以 MIT 许可发布,可在 Github 上找到。

这是在生产中部署 PyTorch 模型的系列教程中的第一篇。 到目前为止,以这种方式使用 Flask 是开始为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的用例。 为了那个原因:

API 定义

我们将首先定义 API 端点,请求和响应类型。 我们的 API 端点将位于/predict,它通过包含图片的file参数接受 HTTP POST 请求。 响应将是包含预测的 JSON 响应:

  1. {"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖项

通过运行以下命令来安装所需的依赖项:

  1. $ pip install Flask==1.0.3 torchvision-0.3.0

简单的 Web 服务器

以下是一个简单的网络服务器,摘自 Flask 的文档

  1. from flask import Flask
  2. app = Flask(__name__)
  3. @app.route('/')
  4. def hello():
  5. return 'Hello World!'

将以上代码段保存在名为app.py的文件中,您现在可以通过输入以下内容来运行 Flask 开发服务器:

  1. $ FLASK_ENV=development FLASK_APP=app.py flask run

当您在网络浏览器中访问http://localhost:5000/时,您会看到Hello World!文字

我们将对上面的代码片段进行一些更改,以使其适合我们的 API 定义。 首先,我们将方法重命名为predict。 我们将端点路径更新为/predict。 由于图像文件将通过 HTTP POST 请求发送,因此我们将对其进行更新,使其也仅接受 POST 请求:

  1. @app.route('/predict', methods=['POST'])
  2. def predict():
  3. return 'Hello World!'

我们还将更改响应类型,以使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。 更新后的app.py文件现在为:

  1. from flask import Flask, jsonify
  2. app = Flask(__name__)
  3. @app.route('/predict', methods=['POST'])
  4. def predict():
  5. return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推断

在下一部分中,我们将重点介绍编写推理代码。 这将涉及两部分,第一部分是准备图像,以便可以将其馈送到 DenseNet;第二部分,我们将编写代码以从模型中获取实际的预测。

准备图像

DenseNet 模型要求图像为尺寸为224 x 224的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。 您可以在上阅读有关它的更多信息。

我们将使用torchvision库中的transforms并建立一个转换管道,该转换管道可根据需要转换图像。 您可以这里阅读有关转换的更多信息

  1. import io
  2. import torchvision.transforms as transforms
  3. from PIL import Image
  4. def transform_image(image_bytes):
  5. my_transforms = transforms.Compose([transforms.Resize(255),
  6. transforms.CenterCrop(224),
  7. transforms.ToTensor(),
  8. transforms.Normalize(
  9. [0.485, 0.456, 0.406],
  10. [0.229, 0.224, 0.225])])
  11. image = Image.open(io.BytesIO(image_bytes))
  12. return my_transforms(image).unsqueeze(0)

上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。 要测试上述方法,请以字节模式读取图像文件(首先将../_static/img/sample_file.jpeg替换为计算机上文件的实际路径),然后查看是否取回张量:

  1. with open("../_static/img/sample_file.jpeg", 'rb') as f:
  2. image_bytes = f.read()
  3. tensor = transform_image(image_bytes=image_bytes)
  4. print(tensor)

出:

  1. tensor([[[[ 0.4508, 0.4166, 0.3994, ..., -1.3473, -1.3302, -1.3473],
  2. [ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302],
  3. [ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644],
  4. ...,
  5. [ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667],
  6. [ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468],
  7. [ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]],
  8. [[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529],
  9. [ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354],
  10. [ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704],
  11. ...,
  12. [ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508],
  13. [ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283],
  14. [ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]],
  15. [[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476],
  16. [ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650],
  17. [ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999],
  18. ...,
  19. [ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346],
  20. [ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125],
  21. [ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]])

预测

现在将使用预训练的 DenseNet 121 模型来预测图像类别。 我们将使用torchvision库中的一个,加载模型并进行推断。 在此示例中,我们将使用预训练模型,但您可以对自己的模型使用相同的方法。 在此教程中查看有关加载模型的更多信息。

  1. from torchvision import models
  2. # Make sure to pass `pretrained` as `True` to use the pretrained weights:
  3. model = models.densenet121(pretrained=True)
  4. # Since we are using our model only for inference, switch to `eval` mode:
  5. model.eval()
  6. def get_prediction(image_bytes):
  7. tensor = transform_image(image_bytes=image_bytes)
  8. outputs = model.forward(tensor)
  9. _, y_hat = outputs.max(1)
  10. return y_hat

张量y_hat将包含预测的类 ID 的索引。 但是,我们需要一个人类可读的类名。 为此,我们需要一个类 ID 来进行名称映射。 将这个文件下载为imagenet_class_index.json,并记住它的保存位置(或者,如果您按照本教程中的确切步骤操作,请将其保存在tutorials/_static中)。 此文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。 我们将加载此 JSON 文件并获取预测索引的类名称。

  1. import json
  2. imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
  3. def get_prediction(image_bytes):
  4. tensor = transform_image(image_bytes=image_bytes)
  5. outputs = model.forward(tensor)
  6. _, y_hat = outputs.max(1)
  7. predicted_idx = str(y_hat.item())
  8. return imagenet_class_index[predicted_idx]

在使用imagenet_class_index字典之前,首先我们将张量值转换为字符串值,因为imagenet_class_index字典中的键是字符串。 我们将测试上述方法:

  1. with open("../_static/img/sample_file.jpeg", 'rb') as f:
  2. image_bytes = f.read()
  3. print(get_prediction(image_bytes=image_bytes))

出:

  1. ['n02124075', 'Egyptian_cat']

您应该得到如下响应:

  1. ['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类 ID,第二项是人类可读的名称。

注意

您是否注意到model变量不属于get_prediction方法? 还是为什么模型是全局变量? 就内存和计算而言,加载模型可能是一项昂贵的操作。 如果我们以get_prediction方法加载模型,则每次调用该方法时都会不必要地加载该模型。 由于我们正在构建一个 Web 服务器,因此每秒可能有成千上万的请求,因此我们不应该浪费时间为每个推断重复加载模型。 因此,我们仅将模型加载到内存中一次。 在生产系统中,必须高效使用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。

将模型集成到我们的 API 服务器中

在最后一部分中,我们将模型添加到 Flask API 服务器中。 由于我们的 API 服务器应该获取图像文件,因此我们将更新predict方法以从请求中读取文件:

  1. from flask import request
  2. @app.route('/predict', methods=['POST'])
  3. def predict():
  4. if request.method == 'POST':
  5. # we will get the file from the request
  6. file = request.files['file']
  7. # convert that to bytes
  8. img_bytes = file.read()
  9. class_id, class_name = get_prediction(image_bytes=img_bytes)
  10. return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现在完成。 以下是完整版本; 将路径替换为保存文件的路径,它应运行:

  1. import io
  2. import json
  3. from torchvision import models
  4. import torchvision.transforms as transforms
  5. from PIL import Image
  6. from flask import Flask, jsonify, request
  7. app = Flask(__name__)
  8. imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
  9. model = models.densenet121(pretrained=True)
  10. model.eval()
  11. def transform_image(image_bytes):
  12. my_transforms = transforms.Compose([transforms.Resize(255),
  13. transforms.CenterCrop(224),
  14. transforms.ToTensor(),
  15. transforms.Normalize(
  16. [0.485, 0.456, 0.406],
  17. [0.229, 0.224, 0.225])])
  18. image = Image.open(io.BytesIO(image_bytes))
  19. return my_transforms(image).unsqueeze(0)
  20. def get_prediction(image_bytes):
  21. tensor = transform_image(image_bytes=image_bytes)
  22. outputs = model.forward(tensor)
  23. _, y_hat = outputs.max(1)
  24. predicted_idx = str(y_hat.item())
  25. return imagenet_class_index[predicted_idx]
  26. @app.route('/predict', methods=['POST'])
  27. def predict():
  28. if request.method == 'POST':
  29. file = request.files['file']
  30. img_bytes = file.read()
  31. class_id, class_name = get_prediction(image_bytes=img_bytes)
  32. return jsonify({'class_id': class_id, 'class_name': class_name})
  33. if __name__ == '__main__':
  34. app.run()

让我们测试一下我们的网络服务器! 跑:

  1. $ FLASK_ENV=development FLASK_APP=app.py flask run

我们可以使用requests库向我们的应用发送 POST 请求:

  1. import requests
  2. resp = requests.post("http://localhost:5000/predict",
  3. files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

现在打印resp.json()将显示以下内容:

  1. {"class_id": "n02124075", "class_name": "Egyptian_cat"}

后续步骤

我们编写的服务器非常琐碎,可能无法完成生产应用所需的一切。 因此,您可以采取一些措施来改善它:

  • 端点/predict假定请求中始终会有一个图像文件。 这可能并不适用于所有请求。 我们的用户可能发送带有其他参数的图像,或者根本不发送任何图像。
  • 用户也可以发送非图像类型的文件。 由于我们没有处理错误,因此这将破坏我们的服务器。 添加显式的错误处理路径将引发异常,这将使我们能够更好地处理错误的输入
  • 即使模型可以识别大量类别的图像,也可能无法识别所有图像。 增强实现以处理模型无法识别图像中的任何情况的情况。
  • 我们在开发模式下运行 Flask 服务器,该服务器不适合在生产中进行部署。 您可以查看本教程,以便在生产环境中部署 Flask 服务器。
  • 您还可以通过创建一个带有表单的页面来添加 UI,该表单可以拍摄图像并显示预测。 查看类似项目的演示及其源代码
  • 在本教程中,我们仅展示了如何构建可以一次返回单个图像预测的服务。 我们可以修改服务以能够一次返回多个图像的预测。 此外,service-streamer 库自动将对服务的请求排队,并将请求采样到微型批量中,这些微型批量可输入模型中。 您可以查看本教程
  • 最后,我们鼓励您在页面顶部查看链接到的其他 PyTorch 模型部署教程。

脚本的总运行时间:(0 分钟 1.232 秒)

下载 Python 源码:flask_rest_api_tutorial.py

下载 Jupyter 笔记本:flask_rest_api_tutorial.ipynb

由 Sphinx 画廊生成的画廊