Text Recognition in Practice

The previous chapter covered the theory and introduced the main approaches in the text recognition field. Among them, CRNN was proposed relatively early and remains one of the most widely used methods in industry today. This chapter describes in detail how to build, train, evaluate, and run inference with a CRNN text recognition model using PaddleOCR. The dataset is ICDAR 2015, with 4,468 training images and 2,077 test images.

After studying this chapter, you will be able to:

  1. Run text recognition inference quickly with the paddleocr whl package
  2. Explain the basic principle and network structure of CRNN
  3. Carry out the required steps of model training and tune its hyperparameters
  4. Train the network on a custom dataset

1. Quick Start

1.1 Install the dependencies and whl package

First make sure paddle and paddleocr are installed; if they already are, skip this step.

# Install the GPU build of PaddlePaddle
!pip install paddlepaddle-gpu
# Install the paddleocr whl package
!pip install -U pip
!pip install paddleocr
Successfully installed pip-21.3.1
Successfully installed PyWavelets-1.2.0 cssselect-1.1.0 cssutils-2.3.0 fasttext-0.9.1 imgaug-0.4.0 lmdb-1.2.1 lxml-4.7.1 opencv-contrib-python-4.4.0.46 paddleocr-2.3.0.2 premailer-3.10.0 pybind11-2.8.1 pyclipper-1.3.0.post2 python-Levenshtein-0.12.2 scikit-image-0.19.1 shapely-1.8.0 tifffile-2021.11.2

1.2 Quickly predict text content

The paddleocr whl package automatically downloads the lightweight PP-OCR models as its defaults.

The following shows how to run recognition inference with the whl package:

Test image: Figure 1

from paddleocr import PaddleOCR

ocr = PaddleOCR()  # need to run only once to download and load model into memory
img_path = '/home/aistudio/work/word_19.png'
result = ocr.ocr(img_path, det=False)
for line in result:
    print(line)
[2021/12/23 19:06:41] root WARNING: version PP-OCRv2 not support cls models, auto switch to version PP-OCR
download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar to /home/aistudio/.paddleocr/2.3.0.2/ocr/det/ch/ch_PP-OCRv2_det_infer/ch_PP-OCRv2_det_infer.tar
download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar to /home/aistudio/.paddleocr/2.3.0.2/ocr/rec/ch/ch_PP-OCRv2_rec_infer/ch_PP-OCRv2_rec_infer.tar
download https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar to /home/aistudio/.paddleocr/2.3.0.2/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar
[2021/12/23 19:06:45] root WARNING: Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process
('SLOW', 0.8881992)

Running the code above returns the recognized text together with its confidence score:

('SLOW', 0.8881992)

At this point you know how to run inference with the paddleocr whl package. There are more test images under ./work/; feel free to try them as well.

2. How Prediction Works

In Section 1, paddleocr loaded a pretrained CRNN recognition model for inference; this section explains the principle and workflow of CRNN in detail.

2.1 Algorithm category

CRNN is a CTC-based algorithm; in the taxonomy introduced in the theory chapter, it sits in the position shown below. CRNN mainly targets regular (horizontal) text, and CTC-based algorithms offer fast inference while handling long text well. That is why CRNN is the Chinese recognition algorithm adopted by PP-OCR.
image.png

2.2 Algorithm details

The CRNN network architecture is shown below; from bottom to top it consists of three parts: convolutional layers, recurrent layers, and a transcription layer:
image.png
1) Backbone:
A convolutional network serves as the low-level backbone that extracts a feature sequence from the input image. Since conv, max-pooling, element-wise operations, and activation functions all act on local regions, they are translation invariant. Each column of the feature map therefore corresponds to a rectangular region of the original image (its receptive field), and these rectangular regions appear in the same left-to-right order as their corresponding columns in the feature map. Because a CNN must scale its input to a fixed size to satisfy its fixed input dimensions, it is poorly suited to sequence objects whose length varies greatly. To better support variable-length sequences, CRNN feeds the feature maps produced by the last backbone layer into the RNN layers, converting them into sequence features.
image.png
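The map-to-sequence step can be sketched with plain numpy. The shapes below are illustrative only (the backbone built later in this chapter reduces a 3×32×320 input to a feature map of height 1), not PP-OCR's exact tensor sizes:

```python
import numpy as np

# A hypothetical backbone output: batch=1, 288 channels, height 1, width 80
feature_map = np.random.rand(1, 288, 1, 80).astype("float32")

# Drop the height axis and swap axes: each of the 80 feature-map columns
# becomes one timestep carrying a 288-dim feature vector
sequence = feature_map.squeeze(axis=2).transpose((0, 2, 1))
print(sequence.shape)  # (1, 80, 288)
```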
2) Neck:
The recurrent layers build a recurrent network on top of the convolutional features, converting the image features into sequence features and predicting a label distribution for every frame. RNNs have a strong ability to capture contextual information within a sequence, and using contextual cues for image-based sequence recognition is more effective than handling each frame independently. In scene text recognition, for example, a wide character may need several consecutive frames to be fully described, and some ambiguous characters are easier to distinguish once their context is observed. Second, the RNN can back-propagate the error differentials to the convolutional layers, so the whole network can be trained jointly. Third, an RNN can operate on sequences of arbitrary length, which solves the problem of variable-width text images. CRNN uses a two-layer LSTM as its recurrent layers, which mitigates the vanishing and exploding gradients that occur when training on long sequences.
image.png
3) Head:
The transcription layer converts the per-frame predictions into the final label sequence via a fully connected network with a softmax activation. Finally, CTC loss enables joint training of the CNN and RNN without requiring any sequence alignment. CTC comes with a dedicated sequence-merging mechanism: after the LSTM emits its output sequence, each timestep is classified to obtain a prediction, and since several timesteps may predict the same class, identical consecutive results must be merged. To avoid merging characters that are genuinely repeated, CTC introduces a blank symbol that is inserted between repeated characters.
image.png

2.3 Code implementation

The overall network structure is quite compact, and the implementation is correspondingly simple; the modules can be built one by one following the inference pipeline. This section covers: data input, building the backbone, building the neck, and building the head.

[Data input]

Before entering the network, images are resized to a unified shape of (3, 32, 320) and normalized. The data augmentation used during training is omitted here; the required preprocessing steps are shown on a single image (source code location):

import cv2
import math
import numpy as np

def resize_norm_img(img):
    """
    Resize and normalize the image
    :param img: input image
    """

    # default input shape
    imgC = 3
    imgH = 32
    imgW = 320

    # actual height and width of the image
    h, w = img.shape[:2]
    # actual aspect ratio
    ratio = w / float(h)

    # resize while keeping the aspect ratio
    if math.ceil(imgH * ratio) > imgW:
        # if wider than the default width, clip to imgW
        resized_w = imgW
    else:
        # otherwise use the image's scaled width
        resized_w = int(math.ceil(imgH * ratio))
    # resize
    resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    # normalize to [-1, 1]
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    # pad the remaining width with zeros
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    # transpose the padded image back for visualization
    draw_img = padding_im.transpose((1,2,0))
    return padding_im, draw_img
import matplotlib.pyplot as plt
# read the image
raw_img = cv2.imread("/home/aistudio/work/word_1.png")
plt.figure()
plt.subplot(2,1,1)
# visualize the original image
plt.imshow(raw_img)
# resize and normalize
padding_im, draw_img = resize_norm_img(raw_img)
plt.subplot(2,1,2)
# visualize the network input image
plt.imshow(draw_img)
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

image.png

[Network structure]

  • backbone

PaddleOCR uses MobileNetV3 as the backbone, assembled in the same order as the network structure. First define the shared modules used throughout the network (source code location): ConvBNLayer, SEModule, ResidualUnit, and make_divisible.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None):
        """
        卷积BN层
        :param in_channels: 输入通道数
        :param out_channels: 输出通道数
        :param kernel_size: 卷积核尺寸
        :parma stride: 步长大小
        :param padding: 填充大小
        :param groups: 二维卷积层的组数
        :param if_act: 是否添加激活函数
        :param act: 激活函数
        """
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)

        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)

    def forward(self, x):
        # conv layer
        x = self.conv(x)
        # batch-norm layer
        x = self.bn(x)
        # apply the activation function if enabled
        if self.if_act:
            if self.act == "relu":
                x = F.relu(x)
            elif self.act == "hardswish":
                x = F.hardswish(x)
            else:
                print("The activation function({}) is selected incorrectly.".
                      format(self.act))
                exit()
        return x

class SEModule(nn.Layer):
    def __init__(self, in_channels, reduction=4):
        """
        SE模块
        :param in_channels: 输入通道数
        :param reduction: 通道缩放率
        """        
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.conv2 = nn.Conv2D(
            in_channels=in_channels // reduction,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self, inputs):
        # global average pooling
        outputs = self.avg_pool(inputs)
        # first conv layer
        outputs = self.conv1(outputs)
        # relu activation
        outputs = F.relu(outputs)
        # second conv layer
        outputs = self.conv2(outputs)
        # hardsigmoid activation
        outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
        return inputs * outputs


class ResidualUnit(nn.Layer):
    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 use_se,
                 act=None):
        """
        残差层
        :param in_channels: 输入通道数
        :param mid_channels: 中间通道数
        :param out_channels: 输出通道数
        :param kernel_size: 卷积核尺寸
        :parma stride: 步长大小
        :param use_se: 是否使用se模块
        :param act: 激活函数
        """ 
        super(ResidualUnit, self).__init__()
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se

        self.expand_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,
            if_act=True,
            act=act)
        if self.if_se:
            self.mid_se = SEModule(mid_channels)
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, inputs):
        x = self.expand_conv(inputs)
        x = self.bottleneck_conv(x)
        if self.if_se:
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
            x = paddle.add(inputs, x)
        return x


def make_divisible(v, divisor=8, min_value=None):
    """
    确保被8整除
    """ 
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

Assemble the backbone network from these shared modules:

class MobileNetV3(nn.Layer):
    def __init__(self,
                 in_channels=3,
                 model_name='small',
                 scale=0.5,
                 small_stride=None,
                 disable_se=False,
                 **kwargs):
        super(MobileNetV3, self).__init__()
        self.disable_se = disable_se

        if small_stride is None:
            small_stride = [1, 2, 2, 2]

        if model_name == "small":
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, 'relu', (small_stride[0], 1)],
                [3, 72, 24, False, 'relu', (small_stride[1], 1)],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert scale in supported_scale, \
            "supported scales are {} but input scale is {}".format(supported_scale, scale)

        inplanes = 16
        # conv1
        self.conv1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')
        i = 0
        block_list = []
        inplanes = make_divisible(inplanes * scale)
        for (k, exp, c, se, nl, s) in cfg:
            se = se and not self.disable_se
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl))
            inplanes = make_divisible(scale * c)
            i += 1
        self.blocks = nn.Sequential(*block_list)

        self.conv2 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            if_act=True,
            act='hardswish')

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = make_divisible(scale * cls_ch_squeeze)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x

至此就完成了骨干网络的定义,可通过 paddle.summary 可视化整个网络结构:

# 定义网络输入shape
IMAGE_SHAPE_C = 3
IMAGE_SHAPE_H = 32
IMAGE_SHAPE_W = 320


# 可视化网络结构
paddle.summary(MobileNetV3(),[(1, IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)])
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
     Conv2D-1        [[1, 3, 32, 320]]     [1, 8, 16, 160]          216      
    BatchNorm-1      [[1, 8, 16, 160]]     [1, 8, 16, 160]          32       
   ConvBNLayer-1     [[1, 3, 32, 320]]     [1, 8, 16, 160]           0       
     Conv2D-2        [[1, 8, 16, 160]]     [1, 8, 16, 160]          64       
    BatchNorm-2      [[1, 8, 16, 160]]     [1, 8, 16, 160]          32       
   ConvBNLayer-2     [[1, 8, 16, 160]]     [1, 8, 16, 160]           0       
     Conv2D-3        [[1, 8, 16, 160]]     [1, 8, 16, 160]          72       
    BatchNorm-3      [[1, 8, 16, 160]]     [1, 8, 16, 160]          32       
   ConvBNLayer-3     [[1, 8, 16, 160]]     [1, 8, 16, 160]           0       
AdaptiveAvgPool2D-1  [[1, 8, 16, 160]]       [1, 8, 1, 1]            0       
     Conv2D-4          [[1, 8, 1, 1]]        [1, 2, 1, 1]           18       
     Conv2D-5          [[1, 2, 1, 1]]        [1, 8, 1, 1]           24       
    SEModule-1       [[1, 8, 16, 160]]     [1, 8, 16, 160]           0       
     Conv2D-6        [[1, 8, 16, 160]]     [1, 8, 16, 160]          64       
    BatchNorm-4      [[1, 8, 16, 160]]     [1, 8, 16, 160]          32       
   ConvBNLayer-4     [[1, 8, 16, 160]]     [1, 8, 16, 160]           0       
  ResidualUnit-1     [[1, 8, 16, 160]]     [1, 8, 16, 160]           0       
     Conv2D-7        [[1, 8, 16, 160]]     [1, 40, 16, 160]         320      
    BatchNorm-5      [[1, 40, 16, 160]]    [1, 40, 16, 160]         160      
   ConvBNLayer-5     [[1, 8, 16, 160]]     [1, 40, 16, 160]          0       
     Conv2D-8        [[1, 40, 16, 160]]    [1, 40, 8, 160]          360      
    BatchNorm-6      [[1, 40, 8, 160]]     [1, 40, 8, 160]          160      
   ConvBNLayer-6     [[1, 40, 16, 160]]    [1, 40, 8, 160]           0       
     Conv2D-9        [[1, 40, 8, 160]]     [1, 16, 8, 160]          640      
    BatchNorm-7      [[1, 16, 8, 160]]     [1, 16, 8, 160]          64       
   ConvBNLayer-7     [[1, 40, 8, 160]]     [1, 16, 8, 160]           0       
  ResidualUnit-2     [[1, 8, 16, 160]]     [1, 16, 8, 160]           0       
     Conv2D-10       [[1, 16, 8, 160]]     [1, 48, 8, 160]          768      
    BatchNorm-8      [[1, 48, 8, 160]]     [1, 48, 8, 160]          192      
   ConvBNLayer-8     [[1, 16, 8, 160]]     [1, 48, 8, 160]           0       
     Conv2D-11       [[1, 48, 8, 160]]     [1, 48, 8, 160]          432      
    BatchNorm-9      [[1, 48, 8, 160]]     [1, 48, 8, 160]          192      
   ConvBNLayer-9     [[1, 48, 8, 160]]     [1, 48, 8, 160]           0       
     Conv2D-12       [[1, 48, 8, 160]]     [1, 16, 8, 160]          768      
   BatchNorm-10      [[1, 16, 8, 160]]     [1, 16, 8, 160]          64       
  ConvBNLayer-10     [[1, 48, 8, 160]]     [1, 16, 8, 160]           0       
  ResidualUnit-3     [[1, 16, 8, 160]]     [1, 16, 8, 160]           0       
     Conv2D-13       [[1, 16, 8, 160]]     [1, 48, 8, 160]          768      
   BatchNorm-11      [[1, 48, 8, 160]]     [1, 48, 8, 160]          192      
  ConvBNLayer-11     [[1, 16, 8, 160]]     [1, 48, 8, 160]           0       
     Conv2D-14       [[1, 48, 8, 160]]     [1, 48, 4, 160]         1,200     
   BatchNorm-12      [[1, 48, 4, 160]]     [1, 48, 4, 160]          192      
  ConvBNLayer-12     [[1, 48, 8, 160]]     [1, 48, 4, 160]           0       
AdaptiveAvgPool2D-2  [[1, 48, 4, 160]]      [1, 48, 1, 1]            0       
     Conv2D-15        [[1, 48, 1, 1]]       [1, 12, 1, 1]           588      
     Conv2D-16        [[1, 12, 1, 1]]       [1, 48, 1, 1]           624      
    SEModule-2       [[1, 48, 4, 160]]     [1, 48, 4, 160]           0       
     Conv2D-17       [[1, 48, 4, 160]]     [1, 24, 4, 160]         1,152     
   BatchNorm-13      [[1, 24, 4, 160]]     [1, 24, 4, 160]          96       
  ConvBNLayer-13     [[1, 48, 4, 160]]     [1, 24, 4, 160]           0       
  ResidualUnit-4     [[1, 16, 8, 160]]     [1, 24, 4, 160]           0       
     Conv2D-18       [[1, 24, 4, 160]]     [1, 120, 4, 160]        2,880     
   BatchNorm-14      [[1, 120, 4, 160]]    [1, 120, 4, 160]         480      
  ConvBNLayer-14     [[1, 24, 4, 160]]     [1, 120, 4, 160]          0       
     Conv2D-19       [[1, 120, 4, 160]]    [1, 120, 4, 160]        3,000     
   BatchNorm-15      [[1, 120, 4, 160]]    [1, 120, 4, 160]         480      
  ConvBNLayer-15     [[1, 120, 4, 160]]    [1, 120, 4, 160]          0       
AdaptiveAvgPool2D-3  [[1, 120, 4, 160]]     [1, 120, 1, 1]           0       
     Conv2D-20        [[1, 120, 1, 1]]      [1, 30, 1, 1]          3,630     
     Conv2D-21        [[1, 30, 1, 1]]       [1, 120, 1, 1]         3,720     
    SEModule-3       [[1, 120, 4, 160]]    [1, 120, 4, 160]          0       
     Conv2D-22       [[1, 120, 4, 160]]    [1, 24, 4, 160]         2,880     
   BatchNorm-16      [[1, 24, 4, 160]]     [1, 24, 4, 160]          96       
  ConvBNLayer-16     [[1, 120, 4, 160]]    [1, 24, 4, 160]           0       
  ResidualUnit-5     [[1, 24, 4, 160]]     [1, 24, 4, 160]           0       
     Conv2D-23       [[1, 24, 4, 160]]     [1, 120, 4, 160]        2,880     
   BatchNorm-17      [[1, 120, 4, 160]]    [1, 120, 4, 160]         480      
  ConvBNLayer-17     [[1, 24, 4, 160]]     [1, 120, 4, 160]          0       
     Conv2D-24       [[1, 120, 4, 160]]    [1, 120, 4, 160]        3,000     
   BatchNorm-18      [[1, 120, 4, 160]]    [1, 120, 4, 160]         480      
  ConvBNLayer-18     [[1, 120, 4, 160]]    [1, 120, 4, 160]          0       
AdaptiveAvgPool2D-4  [[1, 120, 4, 160]]     [1, 120, 1, 1]           0       
     Conv2D-25        [[1, 120, 1, 1]]      [1, 30, 1, 1]          3,630     
     Conv2D-26        [[1, 30, 1, 1]]       [1, 120, 1, 1]         3,720     
    SEModule-4       [[1, 120, 4, 160]]    [1, 120, 4, 160]          0       
     Conv2D-27       [[1, 120, 4, 160]]    [1, 24, 4, 160]         2,880     
   BatchNorm-19      [[1, 24, 4, 160]]     [1, 24, 4, 160]          96       
  ConvBNLayer-19     [[1, 120, 4, 160]]    [1, 24, 4, 160]           0       
  ResidualUnit-6     [[1, 24, 4, 160]]     [1, 24, 4, 160]           0       
     Conv2D-28       [[1, 24, 4, 160]]     [1, 64, 4, 160]         1,536     
   BatchNorm-20      [[1, 64, 4, 160]]     [1, 64, 4, 160]          256      
  ConvBNLayer-20     [[1, 24, 4, 160]]     [1, 64, 4, 160]           0       
     Conv2D-29       [[1, 64, 4, 160]]     [1, 64, 4, 160]         1,600     
   BatchNorm-21      [[1, 64, 4, 160]]     [1, 64, 4, 160]          256      
  ConvBNLayer-21     [[1, 64, 4, 160]]     [1, 64, 4, 160]           0       
AdaptiveAvgPool2D-5  [[1, 64, 4, 160]]      [1, 64, 1, 1]            0       
     Conv2D-30        [[1, 64, 1, 1]]       [1, 16, 1, 1]          1,040     
     Conv2D-31        [[1, 16, 1, 1]]       [1, 64, 1, 1]          1,088     
    SEModule-5       [[1, 64, 4, 160]]     [1, 64, 4, 160]           0       
     Conv2D-32       [[1, 64, 4, 160]]     [1, 24, 4, 160]         1,536     
   BatchNorm-22      [[1, 24, 4, 160]]     [1, 24, 4, 160]          96       
  ConvBNLayer-22     [[1, 64, 4, 160]]     [1, 24, 4, 160]           0       
  ResidualUnit-7     [[1, 24, 4, 160]]     [1, 24, 4, 160]           0       
     Conv2D-33       [[1, 24, 4, 160]]     [1, 72, 4, 160]         1,728     
   BatchNorm-23      [[1, 72, 4, 160]]     [1, 72, 4, 160]          288      
  ConvBNLayer-23     [[1, 24, 4, 160]]     [1, 72, 4, 160]           0       
     Conv2D-34       [[1, 72, 4, 160]]     [1, 72, 4, 160]         1,800     
   BatchNorm-24      [[1, 72, 4, 160]]     [1, 72, 4, 160]          288      
  ConvBNLayer-24     [[1, 72, 4, 160]]     [1, 72, 4, 160]           0       
AdaptiveAvgPool2D-6  [[1, 72, 4, 160]]      [1, 72, 1, 1]            0       
     Conv2D-35        [[1, 72, 1, 1]]       [1, 18, 1, 1]          1,314     
     Conv2D-36        [[1, 18, 1, 1]]       [1, 72, 1, 1]          1,368     
    SEModule-6       [[1, 72, 4, 160]]     [1, 72, 4, 160]           0       
     Conv2D-37       [[1, 72, 4, 160]]     [1, 24, 4, 160]         1,728     
   BatchNorm-25      [[1, 24, 4, 160]]     [1, 24, 4, 160]          96       
  ConvBNLayer-25     [[1, 72, 4, 160]]     [1, 24, 4, 160]           0       
  ResidualUnit-8     [[1, 24, 4, 160]]     [1, 24, 4, 160]           0       
     Conv2D-38       [[1, 24, 4, 160]]     [1, 144, 4, 160]        3,456     
   BatchNorm-26      [[1, 144, 4, 160]]    [1, 144, 4, 160]         576      
  ConvBNLayer-26     [[1, 24, 4, 160]]     [1, 144, 4, 160]          0       
     Conv2D-39       [[1, 144, 4, 160]]    [1, 144, 2, 160]        3,600     
   BatchNorm-27      [[1, 144, 2, 160]]    [1, 144, 2, 160]         576      
  ConvBNLayer-27     [[1, 144, 4, 160]]    [1, 144, 2, 160]          0       
AdaptiveAvgPool2D-7  [[1, 144, 2, 160]]     [1, 144, 1, 1]           0       
     Conv2D-40        [[1, 144, 1, 1]]      [1, 36, 1, 1]          5,220     
     Conv2D-41        [[1, 36, 1, 1]]       [1, 144, 1, 1]         5,328     
    SEModule-7       [[1, 144, 2, 160]]    [1, 144, 2, 160]          0       
     Conv2D-42       [[1, 144, 2, 160]]    [1, 48, 2, 160]         6,912     
   BatchNorm-28      [[1, 48, 2, 160]]     [1, 48, 2, 160]          192      
  ConvBNLayer-28     [[1, 144, 2, 160]]    [1, 48, 2, 160]           0       
  ResidualUnit-9     [[1, 24, 4, 160]]     [1, 48, 2, 160]           0       
     Conv2D-43       [[1, 48, 2, 160]]     [1, 288, 2, 160]       13,824     
   BatchNorm-29      [[1, 288, 2, 160]]    [1, 288, 2, 160]        1,152     
  ConvBNLayer-29     [[1, 48, 2, 160]]     [1, 288, 2, 160]          0       
     Conv2D-44       [[1, 288, 2, 160]]    [1, 288, 2, 160]        7,200     
   BatchNorm-30      [[1, 288, 2, 160]]    [1, 288, 2, 160]        1,152     
  ConvBNLayer-30     [[1, 288, 2, 160]]    [1, 288, 2, 160]          0       
AdaptiveAvgPool2D-8  [[1, 288, 2, 160]]     [1, 288, 1, 1]           0       
     Conv2D-45        [[1, 288, 1, 1]]      [1, 72, 1, 1]         20,808     
     Conv2D-46        [[1, 72, 1, 1]]       [1, 288, 1, 1]        21,024     
    SEModule-8       [[1, 288, 2, 160]]    [1, 288, 2, 160]          0       
     Conv2D-47       [[1, 288, 2, 160]]    [1, 48, 2, 160]        13,824     
   BatchNorm-31      [[1, 48, 2, 160]]     [1, 48, 2, 160]          192      
  ConvBNLayer-31     [[1, 288, 2, 160]]    [1, 48, 2, 160]           0       
  ResidualUnit-10    [[1, 48, 2, 160]]     [1, 48, 2, 160]           0       
     Conv2D-48       [[1, 48, 2, 160]]     [1, 288, 2, 160]       13,824     
   BatchNorm-32      [[1, 288, 2, 160]]    [1, 288, 2, 160]        1,152     
  ConvBNLayer-32     [[1, 48, 2, 160]]     [1, 288, 2, 160]          0       
     Conv2D-49       [[1, 288, 2, 160]]    [1, 288, 2, 160]        7,200     
   BatchNorm-33      [[1, 288, 2, 160]]    [1, 288, 2, 160]        1,152     
  ConvBNLayer-33     [[1, 288, 2, 160]]    [1, 288, 2, 160]          0       
AdaptiveAvgPool2D-9  [[1, 288, 2, 160]]     [1, 288, 1, 1]           0       
     Conv2D-50        [[1, 288, 1, 1]]      [1, 72, 1, 1]         20,808     
     Conv2D-51        [[1, 72, 1, 1]]       [1, 288, 1, 1]        21,024     
    SEModule-9       [[1, 288, 2, 160]]    [1, 288, 2, 160]          0       
     Conv2D-52       [[1, 288, 2, 160]]    [1, 48, 2, 160]        13,824     
   BatchNorm-34      [[1, 48, 2, 160]]     [1, 48, 2, 160]          192      
  ConvBNLayer-34     [[1, 288, 2, 160]]    [1, 48, 2, 160]           0       
  ResidualUnit-11    [[1, 48, 2, 160]]     [1, 48, 2, 160]           0       
     Conv2D-53       [[1, 48, 2, 160]]     [1, 288, 2, 160]       13,824     
   BatchNorm-35      [[1, 288, 2, 160]]    [1, 288, 2, 160]        1,152     
  ConvBNLayer-35     [[1, 48, 2, 160]]     [1, 288, 2, 160]          0       
    MaxPool2D-1      [[1, 288, 2, 160]]    [1, 288, 1, 80]           0       
===============================================================================
Total params: 259,056
Trainable params: 246,736
Non-trainable params: 12,320
-------------------------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 44.38
Params size (MB): 0.99
Estimated Total Size (MB): 45.48
-------------------------------------------------------------------------------






{'total_params': 259056, 'trainable_params': 246736}
# 图片输入骨干网络
backbone = MobileNetV3()
# 将numpy数据转换为Tensor
input_data = paddle.to_tensor([padding_im])
# 骨干网络输出
feature = backbone(input_data)
# 查看feature map的维度
print("backbone output:", feature.shape)
backbone output: [1, 288, 1, 80]
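结合上面的 summary 可以验证降采样倍数:高度方向共经历 5 次减半(conv1、三个在高度方向 stride=2 的 block、末端的 MaxPool),由 32 降为 1;宽度方向只有 conv1 和 MaxPool 两次减半,由 320 降为 80。识别模型正是靠在宽度方向保留较长的序列来承载文字信息。用简单算术验证:

```python
# 高度:conv1(/2) + 3 个高度方向 stride=2 的 block(各 /2) + MaxPool(/2)
h = 32 // 2 // 2 // 2 // 2 // 2
# 宽度:conv1(/2) + MaxPool(/2)
w = 320 // 2 // 2
print(h, w)  # 1 80
```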
  • neck

neck 部分将 backbone 输出的视觉特征图转换为一维的序列特征,再送入 LSTM 网络中编码,输出序列特征( 源码位置 ):

class Im2Seq(nn.Layer):
    def __init__(self, in_channels, **kwargs):
        """
        图像特征转换为序列特征
        :param in_channels: 输入通道数
        """ 
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == 1
        x = x.squeeze(axis=2)
        x = x.transpose([0, 2, 1])  # (NWC)(batch, width, channels)
        return x

class EncoderWithRNN(nn.Layer):
    def __init__(self, in_channels, hidden_size):
        super(EncoderWithRNN, self).__init__()
        self.out_channels = hidden_size * 2
        self.lstm = nn.LSTM(
            in_channels, hidden_size, direction='bidirectional', num_layers=2)

    def forward(self, x):
        x, _ = self.lstm(x)
        return x


class SequenceEncoder(nn.Layer):
    def __init__(self, in_channels, hidden_size=48, **kwargs):
        """
        序列编码
        :param in_channels: 输入通道数
        :param hidden_size: 隐藏层size
        """ 
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)

        self.encoder = EncoderWithRNN(
            self.encoder_reshape.out_channels, hidden_size)
        self.out_channels = self.encoder.out_channels

    def forward(self, x):
        x = self.encoder_reshape(x)
        x = self.encoder(x)
        return x
neck = SequenceEncoder(in_channels=288)
sequence = neck(feature)
print("sequence shape:", sequence.shape)
sequence shape: [1, 80, 96]
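Im2Seq 的维度变换可以脱离网络、用 numpy 单独验证(下面仅演示 squeeze 与 transpose 的形状变化,并非实际实现):

```python
import numpy as np

# 模拟 backbone 输出的特征图:(batch, channels, height=1, width)
feature = np.zeros((1, 288, 1, 80), dtype=np.float32)

assert feature.shape[2] == 1       # Im2Seq 要求特征图高度为 1
x = feature.squeeze(axis=2)        # (1, 288, 80)
x = x.transpose([0, 2, 1])         # (1, 80, 288),即 (batch, width, channels)
print(x.shape)  # (1, 80, 288)
```

随后双向 LSTM(hidden_size=48)把每个时间步的 288 维特征编码为 48×2=96 维,即得到上面的 sequence shape [1, 80, 96]。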
  • head

预测头部分由全连接层和 softmax 组成,用于计算序列特征在每个时间步上的标签概率分布。本示例仅支持识别小写英文字母和数字,共(26+10)36 个类别(源码位置):

class CTCHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 **kwargs):
        """
        CTC 预测层
        :param in_channels: 输入通道数
        :param out_channels: 输出通道数
        """ 
        super(CTCHead, self).__init__()
        self.fc = nn.Linear(
            in_channels,
            out_channels)

        # 思考:out_channels 应该等于多少?
        self.out_channels = out_channels

    def forward(self, x):
        predicts = self.fc(x)
        result = predicts

        if not self.training:
            predicts = F.softmax(predicts, axis=2)
            result = predicts

        return result

在网络随机初始化的情况下,输出结果是无序的,经过 softmax 之后,可以得到各时间步上概率最大的预测结果,其中:pred_id 代表预测的标签 ID,pred_scores 代表预测结果的置信度:

ctc_head = CTCHead(in_channels=96, out_channels=37)
predict = ctc_head(sequence)
print("predict shape:", predict.shape)
result = F.softmax(predict, axis=2)
pred_id = paddle.argmax(result, axis=2)
pred_scores = paddle.max(result, axis=2)
print("pred_id:", pred_id)
print("pred_scores:", pred_scores)
predict shape: [1, 80, 37]
pred_id: Tensor(shape=[1, 80], dtype=int64, place=CUDAPlace(0), stop_gradient=False,
       [[23, 28, 23, 23, 23, 23, 23, 23, 23, 23, 23, 30, 30, 30, 31, 23, 23, 23, 23, 23, 23, 23, 31, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 5 ]])
pred_scores: Tensor(shape=[1, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[0.03683758, 0.03368053, 0.03604801, 0.03504696, 0.03696444, 0.03597261, 0.03925638, 0.03650934, 0.03873367, 0.03572492, 0.03543066, 0.03618268, 0.03805700, 0.03496549, 0.03329032, 0.03565763, 0.03846950, 0.03922413, 0.03970327, 0.03638541, 0.03572393, 0.03618102, 0.03565401, 0.03636984, 0.03691722, 0.03718850, 0.03623354, 0.03877943, 0.03731697, 0.03563465, 0.03447339, 0.03365586, 0.03312979, 0.03285240, 0.03273271, 0.03269565, 0.03269779, 0.03271412, 0.03273287, 0.03274929, 0.03276210, 0.03277146, 0.03277802, 0.03278249, 0.03278547, 0.03278742, 0.03278869, 0.03278949, 0.03279000, 0.03279032, 0.03279052, 0.03279064, 0.03279071, 0.03279077, 0.03279081, 0.03279087, 0.03279094, 0.03279106, 0.03279124, 0.03279152, 0.03279196, 0.03279264, 0.03279363, 0.03279509, 0.03279718, 0.03280006, 0.03280392, 0.03280888, 0.03281487, 0.03282148, 0.03282760, 0.03283087, 0.03282646, 0.03280647, 0.03275031, 0.03263619, 0.03242587, 0.03194289, 0.03122442, 0.02986610]])
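呼应 CTCHead 中的思考题:CTC 需要在字符集之外额外增加一个 blank 类别,因此 out_channels 应为字符集大小 + 1,即上面代码中的 37:

```python
# 小写英文字母 + 数字,共 36 个字符
charset = "0123456789abcdefghijklmnopqrstuvwxyz"
# CTC 需要额外的 1 个 blank 类别,因此 out_channels = 36 + 1
num_classes = len(charset) + 1
print(num_classes)  # 37
```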
  • 后处理

识别网络最终返回的结果是各个时间步上的最大索引值,最终期望的输出是对应的文字结果,因此CRNN的后处理是一个解码过程,主要逻辑如下:

def decode(text_index, text_prob=None, is_remove_duplicate=False):
    """ convert text-index into text-label. """
    character = "-0123456789abcdefghijklmnopqrstuvwxyz"
    result_list = []
    # 忽略tokens [0] 代表ctc中的blank位
    ignored_tokens = [0]
    batch_size = len(text_index)
    for batch_idx in range(batch_size):
        char_list = []
        conf_list = []
        for idx in range(len(text_index[batch_idx])):
            if text_index[batch_idx][idx] in ignored_tokens:
                continue
            # 合并blank之间相同的字符
            if is_remove_duplicate:
                # only for predict
                if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                        batch_idx][idx]:
                    continue
            # 将解码结果存在char_list内
            char_list.append(character[int(text_index[batch_idx][
                idx])])
            # 记录置信度
            if text_prob is not None:
                conf_list.append(text_prob[batch_idx][idx])
            else:
                conf_list.append(1)
        text = ''.join(char_list)
        # 输出结果
        result_list.append((text, np.mean(conf_list)))
    return result_list
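在看随机初始化的预测结果之前,可以先用手工构造的序列直观验证解码规则。下面是一个独立的单样本最小实现(便于单独运行,逻辑与上面 decode 中 is_remove_duplicate=True 的分支一致):

```python
character = "-0123456789abcdefghijklmnopqrstuvwxyz"  # 0 号位是 blank

def ctc_decode_one(indices, remove_duplicate=True):
    """单样本 CTC 解码:跳过 blank(0),并合并相邻重复索引"""
    chars = []
    for i, idx in enumerate(indices):
        if idx == 0:                 # 跳过 blank
            continue
        if remove_duplicate and i > 0 and indices[i - 1] == idx:
            continue                 # 合并相邻重复字符
        chars.append(character[idx])
    return ''.join(chars)

# 'h'=18, 'e'=15, 'l'=22, 'o'=25
# 序列 h h e - l l - l o 解码为 "hello":被 blank 隔开的重复 'l' 不会被合并
print(ctc_decode_one([18, 18, 15, 0, 22, 22, 0, 22, 25]))  # hello
```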

以 head 部分随机初始化预测出的结果为例,进行解码得到:

pred_id = paddle.argmax(result, axis=2)
pred_scores = paddle.max(result, axis=2)
print(pred_id)
decode_out = decode(pred_id, pred_scores)
print("decode out:", decode_out)
Tensor(shape=[1, 80], dtype=int64, place=CUDAPlace(0), stop_gradient=False,
       [[23, 28, 23, 23, 23, 23, 23, 23, 23, 23, 23, 30, 30, 30, 31, 23, 23, 23, 23, 23, 23, 23, 31, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 5 ]])
decode out: [('mrmmmmmmmmmtttummmmmmmummmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm4', 0.034180813)]

小测试: 如果输入的是训练好的模型预测出的 index,解码结果是否正确呢?

# 替换模型预测好的结果
right_pred_id = paddle.to_tensor([['xxxxxxxxxxxxx']])
tmp_scores = paddle.ones(shape=right_pred_id.shape)
out = decode(right_pred_id, tmp_scores)
print("out:",out)
out: [('pain', 1.0)]

上述步骤完成了网络的搭建,也实现了一个简单的前向预测过程。

没有经过训练的网络无法给出正确的预测结果,因此需要定义损失函数和优化策略,把整个网络训练起来,下面将详细介绍网络训练原理。

3. 训练原理详解

3.1 准备训练数据

PaddleOCR 支持两种数据格式:

  • lmdb 用于训练以lmdb格式存储的数据集(LMDBDataSet);
  • 通用数据 用于训练以文本文件存储的数据集(SimpleDataSet);

本次只介绍通用数据格式的读取。

训练数据的默认存储路径是 ./train_data, 执行以下命令解压数据:

!cd /home/aistudio/work/train_data/ && tar xf ic15_data.tar

解压完成后,训练图片都在同一个文件夹内,并有一个txt文件(rec_gt_train.txt)记录图片路径和标签,txt文件里的内容如下:

" 图像文件名         图像标注信息 "

train/word_1.png    Genaxis Theatre
train/word_2.png    [06]
...

注意: txt 文件中默认用 \t(制表符)分隔图片路径和图片标签,如用其他方式分隔将造成训练报错。
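可以用一小段脚本快速检查标注行是否都能按 \t 切出两个字段(check_label_lines 为本文临时定义的辅助函数,这里直接对示例字符串演示):

```python
def check_label_lines(lines, delimiter='\t'):
    """返回无法按 delimiter 切分为 路径+标签 两个字段的行"""
    bad = []
    for i, line in enumerate(lines):
        parts = line.rstrip('\n').split(delimiter)
        if len(parts) != 2:
            bad.append((i, line))
    return bad

lines = ["train/word_1.png\tGenaxis Theatre",
         "train/word_2.png [06]"]          # 第二行误用空格分隔
print(check_label_lines(lines))  # [(1, 'train/word_2.png [06]')]
```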

数据集应有如下文件结构:

|-train_data
  |-ic15_data
    |- rec_gt_train.txt
    |- train
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...
    |- rec_gt_test.txt
    |- test
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...

确认配置文件中的数据路径是否正确,以 rec_icdar15_train.yml为例:

Train:
  dataset:
    name: SimpleDataSet
    # 训练数据根目录
    data_dir: ./train_data/ic15_data/
    # 训练数据标签
    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]  # [3,32,320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8
    use_shared_memory: False

Eval:
  dataset:
    name: SimpleDataSet
    # 评估数据根目录
    data_dir: ./train_data/ic15_data
    # 评估数据标签
    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 4
    use_shared_memory: False

3.2 数据预处理

送入网络的训练数据,需要保证一个 batch 内维度一致;同时,为了让不同维度的特征在数值上具有可比性,需要对数据做统一的尺度缩放和归一化。

为了增强模型的鲁棒性、抑制过拟合并提升泛化性能,还需要做一定的数据增广。

  • 缩放和归一化

第二节中已经介绍了相关内容,这是图片送入网络之前的最后一步操作。调用 resize_norm_img 完成图片缩放、padding和归一化。
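缩放后的归一化与 padding 可以用 numpy 做一个最小示意(norm_and_pad 为示意函数,假设图片已按高 32 等比缩放到宽 100;实际的 resize_norm_img 还包含基于 cv2 的缩放):

```python
import numpy as np

def norm_and_pad(img, target_w=320):
    """归一化到 [-1, 1] 并在右侧补 0 到 target_w(示意函数)"""
    h, w, c = img.shape
    img = img.astype('float32') / 255.0   # 归一化到 [0, 1]
    img = (img - 0.5) / 0.5               # 再归一化到 [-1, 1]
    img = img.transpose((2, 0, 1))        # HWC -> CHW
    padded = np.zeros((c, h, target_w), dtype='float32')
    padded[:, :, :w] = img                # 宽度不足的部分补 0
    return padded

img = np.random.randint(0, 256, (32, 100, 3), dtype=np.uint8)
out = norm_and_pad(img)
print(out.shape)  # (3, 32, 320)
```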

  • 数据增广

PaddleOCR 中实现了多种数据增广方式,如:颜色反转、随机裁剪、仿射变换、随机噪声等等,这里以简单的随机裁剪为例,更多增广方式可参考:rec_img_aug.py

import random

def get_crop(image):
    """
    随机裁剪:随机去掉图片顶部或底部的若干行像素
    """
    h, w, _ = image.shape
    top_min = 1
    top_max = 8
    top_crop = random.randint(top_min, top_max)
    top_crop = min(top_crop, h - 1)
    crop_img = image.copy()
    ratio = random.randint(0, 1)
    if ratio:
        crop_img = crop_img[top_crop:h, :, :]
    else:
        crop_img = crop_img[0:h - top_crop, :, :]
    return crop_img
# 读图
raw_img = cv2.imread("/home/aistudio/work/word_1.png")
plt.figure()
plt.subplot(2,1,1)
# 可视化原图
plt.imshow(raw_img)
# 随机裁剪
crop_img = get_crop(raw_img)
plt.subplot(2,1,2)
# 可视化增广图
plt.imshow(crop_img)
plt.show()

(随机裁剪增广效果可视化图)

3.3 训练主程序

模型训练的入口代码是 train.py,它展示了训练中所需的各个模块:build dataloader、build post process、build model、build loss、build optim、build metric,将各部分串联后即可开始训练:

  • 构建 dataloader

训练模型需要将数据组成指定数目的 batch ,并在训练过程中依次 yield 出来,本例中调用了 PaddleOCR 中实现的 SimpleDataSet

基于原始代码稍作修改,其返回单条数据的主要逻辑如下

def __getitem__(data_line, data_dir):
    import os
    mode = "train"
    delimiter = '\t'
    try:
        substr = data_line.strip("\n").split(delimiter)
        file_name = substr[0]
        label = substr[1]
        img_path = os.path.join(data_dir, file_name)
        data = {'img_path': img_path, 'label': label}
        if not os.path.exists(img_path):
            raise Exception("{} does not exist!".format(img_path))
        with open(data['img_path'], 'rb') as f:
            img = f.read()
            data['image'] = img
        # 预处理操作,先注释掉
        # outs = transform(data, self.ops)
        outs = data
    except Exception as e:
        print("When parsing line {}, error happened with msg: {}".format(
                data_line, e))
        outs = None
    return outs

假设当前输入的标签为 train/word_1.png    Genaxis Theatre,训练数据的路径为 /home/aistudio/work/train_data/ic15_data/,解析出的结果是一个字典,里面包含 img_path、label、image 三个字段:

data_line = "train/word_1.png    Genaxis Theatre"
data_dir = "/home/aistudio/work/train_data/ic15_data/"

item = __getitem__(data_line, data_dir)
print(item)
{'img_path': '/home/aistudio/work/train_data/ic15_data/train/word_1.png', 'label': 'Genaxis Theatre', 'image': b'\x89PNG\r\n\x1a\n...(二进制图像数据,此处省略)'}

实现完单条数据返回逻辑后,调用 paddle.io.DataLoader 即可把数据组合成 batch,具体可参考 build_dataloader
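DataLoader 组 batch 的过程可以用下面这段纯 Python 代码示意(仅为示意,并非 paddle.io.DataLoader 的真实实现;样本形状与上文 dataset 返回的 image、label、length 对应):

```python
import numpy as np

def collate_batch(samples):
    """将若干条 (image, label, length) 样本堆叠成一个 batch(示意)。
    要求各样本 image 形状一致,例如 [3, 32, 100]。"""
    images = np.stack([s[0] for s in samples], axis=0)   # [B, 3, 32, 100]
    labels = np.stack([s[1] for s in samples], axis=0)   # [B, max_text_length]
    lengths = np.array([s[2] for s in samples])          # [B]
    return images, labels, lengths

# 构造 4 条假样本,模拟 dataset[i] 的返回值
samples = [(np.zeros((3, 32, 100), dtype='float32'),
            np.zeros(25, dtype='int32'), 5) for _ in range(4)]
images, labels, lengths = collate_batch(samples)
print(images.shape)   # (4, 3, 32, 100)
```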

  • build model
    build model 即搭建主要网络结构,具体细节如《2.3 代码实现》所述,本节不做过多介绍,各模块代码可参考modeling
  • build loss
    CRNN 模型的损失函数为 CTC loss,飞桨已集成常用的损失函数,直接调用即可:
import paddle
import paddle.nn as nn

class CTCLoss(nn.Layer):
    def __init__(self, use_focal_loss=False, **kwargs):
        super(CTCLoss, self).__init__()
        # blank=0 为 CTC 的空白(blank)符,表示该时间步无字符输出
        self.loss_func = nn.CTCLoss(blank=0, reduction='none')

    def forward(self, predicts, batch):
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]
        # 将 head 层的预测结果由 [B, T, C] 转置为 CTCLoss 要求的 [T, B, C]
        predicts = predicts.transpose((1, 0, 2))  # 例如 [80, 1, 37]
        N, B, _ = predicts.shape
        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype('int64')
        # 计算损失函数
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        loss = loss.mean()
        return {'loss': loss}
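上面 forward 中的转置可以用 numpy 直观验证:CTC loss 要求输入按时间步优先排列,即把 [B, T, C] 变为 [T, B, C](以下形状仅为示例):

```python
import numpy as np

# 假设 head 输出 batch=1、时间步 T=80、类别数 C=37(36 个字符 + blank)
predicts = np.random.rand(1, 80, 37).astype('float32')  # [B, T, C]
# CTC loss 要求时间步在第 0 维,因此转置为 [T, B, C]
predicts_t = predicts.transpose((1, 0, 2))
print(predicts_t.shape)  # (80, 1, 37)
```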
  • build post process

具体细节同样在《2.3 代码实现》有详细介绍,实现逻辑与之前一致。

  • build optim

优化器使用 Adam,同样调用飞桨 API:paddle.optimizer.Adam
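Adam 的更新规则可以用 numpy 做一个单步示意(简化版,仅用于说明一阶/二阶动量与偏差修正,并非飞桨内部实现;学习率、beta1、beta2 取本文配置中的数值):

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=5e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """单步 Adam 更新示意:维护梯度的一阶/二阶滑动平均并做偏差修正。"""
    m = beta1 * m + (1 - beta1) * grad          # 一阶动量
    v = beta2 * v + (1 - beta2) * grad ** 2     # 二阶动量
    m_hat = m / (1 - beta1 ** t)                # 偏差修正
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

w = np.array([1.0])
m = np.zeros(1)
v = np.zeros(1)
w, m, v = adam_step(w, np.array([0.5]), m, v, t=1)
print(w)  # 约 0.9995:第一步的有效步长接近 lr
```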

  • build metric

metric 部分用于计算模型指标。PaddleOCR 的文本识别中,只有整句完全预测正确才计为正确,因此准确率计算的主要逻辑如下:

def metric(preds, labels):
    correct_num = 0
    all_num = 0
    for pred, target in zip(preds, labels):
        # 忽略空格后,整句完全一致才计为正确
        pred = pred.replace(" ", "")
        target = target.replace(" ", "")
        if pred == target:
            correct_num += 1
        all_num += 1
    return {
        'acc': correct_num / all_num,
    }
preds = ["aaa", "bbb", "ccc", "123", "456"]
labels = ["aaa", "bbb", "ddd", "123", "444"]
acc = metric(preds, labels)
print("acc:", acc)
# 五个预测结果中,完全正确的有3个,因此准确率应为0.6
acc: {'acc': 0.6}
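训练日志中还会打印 norm_edit_dis(归一化编辑距离)指标,其含义可以用如下纯 Python 代码示意(简化版,与 PaddleOCR 实际实现细节可能略有出入):

```python
def levenshtein(a, b):
    """经典动态规划求两个字符串的编辑距离。"""
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,        # 删除
                                     dp[j - 1] + 1,    # 插入
                                     prev + (ca != cb))  # 替换
    return dp[-1]

def norm_edit_dis(pred, target):
    # 归一化编辑距离:1 - 编辑距离 / 两串的较大长度(值越大越好)
    return 1 - levenshtein(pred, target) / max(len(pred), len(target), 1)

print(norm_edit_dis("hello", "hallo"))  # 0.8
```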

将以上各部分组合起来,即是完整的训练流程:

def main(config, device, logger, vdl_writer):
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n" +
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            +
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)
    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    use_amp = config["Global"].get("use_amp", False)
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
    else:
        scaler = None

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)

4. 完整训练任务

4.1 启动训练

PaddleOCR 识别任务与检测任务类似,都是通过配置文件传递参数的。

要进行完整的模型训练,首先需要下载整个项目并安装相关依赖:

# 克隆PaddleOCR代码
# !git clone https://gitee.com/paddlepaddle/PaddleOCR
# 修改代码运行的默认目录为 /home/aistudio/PaddleOCR
import os
os.chdir("/home/aistudio/PaddleOCR")
# 安装PaddleOCR第三方依赖
!pip install -r requirements.txt
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: shapely in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.8.0)
Collecting scikit-image==0.17.2
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d7/ee/753ea56fda5bc2a5516a1becb631bf5ada593a2dd44f21971a13a762d4db/scikit_image-0.17.2-cp37-cp37m-manylinux1_x86_64.whl (12.5 MB)
     |████████████████████████████████| 12.5 MB 13.3 MB/s            
Requirement already satisfied: imgaug==0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (0.4.0)
Requirement already satisfied: pyclipper in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.3.0.post2)
Requirement already satisfied: lmdb in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (1.2.1)
Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.36.1)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)
Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)
Requirement already satisfied: python-Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.12.2)
Requirement already satisfied: opencv-contrib-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (4.4.0.46)
Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (4.7.1)
Requirement already satisfied: premailer in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (3.10.0)
Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 13)) (3.0.5)
Requirement already satisfied: networkx>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4)
Requirement already satisfied: imageio>=2.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.6.1)
Requirement already satisfied: scipy>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.6.3)
Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (7.1.2)
Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.2.0)
Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2021.11.2)
Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.2.3)
Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)
Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)
Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)
Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)
Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)
Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)
Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)
Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-Levenshtein->-r requirements.txt (line 9)) (56.2.0)
Requirement already satisfied: cssselect in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (1.1.0)
Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (4.0.0)
Requirement already satisfied: cssutils in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (2.3.0)
Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.0.1)
Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.4.1)
Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)
Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)
Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)
Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)
Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)
Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.8.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (1.1.0)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (0.10.0)
Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (4.4.2)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)
Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)
Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)
Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)
Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)
Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)
Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)
Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)
Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)
Installing collected packages: scikit-image
  Attempting uninstall: scikit-image
    Found existing installation: scikit-image 0.19.1
    Uninstalling scikit-image-0.19.1:
      Successfully uninstalled scikit-image-0.19.1
Successfully installed scikit-image-0.17.2

创建软链,将训练数据放在PaddleOCR项目下:

!ln -s /home/aistudio/work/train_data/ /home/aistudio/PaddleOCR/

下载预训练模型:

为了加快收敛速度,建议下载训练好的预训练模型,在 icdar2015 数据上进行 finetune:

# 注意:notebook 中 !cd 在子 shell 中执行、不会改变工作目录,前面已通过 os.chdir 切换到 PaddleOCR 目录
# 下载MobileNetV3的预训练模型
!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
# 解压模型参数
!tar -xf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar
--2021-12-22 15:39:39--  https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 51200000 (49M) [application/x-tar]
Saving to: ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’

rec_mv3_none_bilstm 100%[===================>]  48.83M  15.5MB/s    in 3.6s    

2021-12-22 15:39:42 (13.7 MB/s) - ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’ saved [51200000/51200000]

启动训练命令很简单,指定好配置文件即可。另外还可以在命令行中通过 -o 覆盖配置文件中的参数值。启动训练命令如下所示。

其中:

  • Global.pretrained_model: 加载的预训练模型路径
  • Global.character_dict_path : 字典路径(该字典仅包含 26 个小写英文字母和 10 个数字)
  • Global.eval_batch_step : 评估频率,[0,200] 表示从第 0 次迭代开始、每 200 次迭代评估一次
  • Global.epoch_num: 总训练轮数
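-o 覆盖参数时用点号表示配置的嵌套层级,其作用方式可用下面的简化示意理解(仅为示意,并非 PaddleOCR 的真实解析代码):

```python
def apply_override(config, opt):
    """把形如 'Global.epoch_num=40' 的命令行覆盖项写入嵌套 dict(简化示意)。"""
    key_path, value = opt.split('=', 1)
    keys = key_path.split('.')
    node = config
    for k in keys[:-1]:
        node = node.setdefault(k, {})  # 逐层进入嵌套配置
    node[keys[-1]] = value
    return config

config = {'Global': {'epoch_num': 72}}
apply_override(config, 'Global.epoch_num=40')
print(config['Global']['epoch_num'])  # 40(此处为字符串,实际实现还会做类型解析)
```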
!python3 tools/train.py -c configs/rec/rec_icdar15_train.yml \
   -o Global.pretrained_model=rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy \
   Global.character_dict_path=ppocr/utils/ic15_dict.txt \
   Global.eval_batch_step=[0,200] \
   Global.epoch_num=40
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)
[2021/12/23 19:28:52] root INFO: Architecture : 
[2021/12/23 19:28:52] root INFO:     Backbone : 
[2021/12/23 19:28:52] root INFO:         model_name : large
[2021/12/23 19:28:52] root INFO:         name : MobileNetV3
[2021/12/23 19:28:52] root INFO:         scale : 0.5
[2021/12/23 19:28:52] root INFO:     Head : 
[2021/12/23 19:28:52] root INFO:         fc_decay : 0
[2021/12/23 19:28:52] root INFO:         name : CTCHead
[2021/12/23 19:28:52] root INFO:     Neck : 
[2021/12/23 19:28:52] root INFO:         encoder_type : rnn
[2021/12/23 19:28:52] root INFO:         hidden_size : 96
[2021/12/23 19:28:52] root INFO:         name : SequenceEncoder
[2021/12/23 19:28:52] root INFO:     Transform : None
[2021/12/23 19:28:52] root INFO:     algorithm : CRNN
[2021/12/23 19:28:52] root INFO:     model_type : rec
[2021/12/23 19:28:52] root INFO: Eval : 
[2021/12/23 19:28:52] root INFO:     dataset : 
[2021/12/23 19:28:52] root INFO:         data_dir : ./train_data/ic15_data
[2021/12/23 19:28:52] root INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
[2021/12/23 19:28:52] root INFO:         name : SimpleDataSet
[2021/12/23 19:28:52] root INFO:         transforms : 
[2021/12/23 19:28:52] root INFO:             DecodeImage : 
[2021/12/23 19:28:52] root INFO:                 channel_first : False
[2021/12/23 19:28:52] root INFO:                 img_mode : BGR
[2021/12/23 19:28:52] root INFO:             CTCLabelEncode : None
[2021/12/23 19:28:52] root INFO:             RecResizeImg : 
[2021/12/23 19:28:52] root INFO:                 image_shape : [3, 32, 100]
[2021/12/23 19:28:52] root INFO:             KeepKeys : 
[2021/12/23 19:28:52] root INFO:                 keep_keys : ['image', 'label', 'length']
[2021/12/23 19:28:52] root INFO:     loader : 
[2021/12/23 19:28:52] root INFO:         batch_size_per_card : 256
[2021/12/23 19:28:52] root INFO:         drop_last : False
[2021/12/23 19:28:52] root INFO:         num_workers : 4
[2021/12/23 19:28:52] root INFO:         shuffle : False
[2021/12/23 19:28:52] root INFO:         use_shared_memory : False
[2021/12/23 19:28:52] root INFO: Global : 
[2021/12/23 19:28:52] root INFO:     cal_metric_during_train : True
[2021/12/23 19:28:52] root INFO:     character_dict_path : ppocr/utils/ic15_dict.txt
[2021/12/23 19:28:52] root INFO:     character_type : EN
[2021/12/23 19:28:52] root INFO:     checkpoints : None
[2021/12/23 19:28:52] root INFO:     debug : False
[2021/12/23 19:28:52] root INFO:     distributed : False
[2021/12/23 19:28:52] root INFO:     epoch_num : 40
[2021/12/23 19:28:52] root INFO:     eval_batch_step : [0, 200]
[2021/12/23 19:28:52] root INFO:     infer_img : doc/imgs_words_en/word_19.png
[2021/12/23 19:28:52] root INFO:     infer_mode : False
[2021/12/23 19:28:52] root INFO:     log_smooth_window : 20
[2021/12/23 19:28:52] root INFO:     max_text_length : 25
[2021/12/23 19:28:52] root INFO:     pretrained_model : rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy
[2021/12/23 19:28:52] root INFO:     print_batch_step : 10
[2021/12/23 19:28:52] root INFO:     save_epoch_step : 3
[2021/12/23 19:28:52] root INFO:     save_inference_dir : ./
[2021/12/23 19:28:52] root INFO:     save_model_dir : ./output/rec/ic15/
[2021/12/23 19:28:52] root INFO:     save_res_path : ./output/rec/predicts_ic15.txt
[2021/12/23 19:28:52] root INFO:     use_gpu : True
[2021/12/23 19:28:52] root INFO:     use_space_char : False
[2021/12/23 19:28:52] root INFO:     use_visualdl : False
[2021/12/23 19:28:52] root INFO: Loss : 
[2021/12/23 19:28:52] root INFO:     name : CTCLoss
[2021/12/23 19:28:52] root INFO: Metric : 
[2021/12/23 19:28:52] root INFO:     main_indicator : acc
[2021/12/23 19:28:52] root INFO:     name : RecMetric
[2021/12/23 19:28:52] root INFO: Optimizer : 
[2021/12/23 19:28:52] root INFO:     beta1 : 0.9
[2021/12/23 19:28:52] root INFO:     beta2 : 0.999
[2021/12/23 19:28:52] root INFO:     lr : 
[2021/12/23 19:28:52] root INFO:         learning_rate : 0.0005
[2021/12/23 19:28:52] root INFO:     name : Adam
[2021/12/23 19:28:52] root INFO:     regularizer : 
[2021/12/23 19:28:52] root INFO:         factor : 0
[2021/12/23 19:28:52] root INFO:         name : L2
[2021/12/23 19:28:52] root INFO: PostProcess : 
[2021/12/23 19:28:52] root INFO:     name : CTCLabelDecode
[2021/12/23 19:28:52] root INFO: Train : 
[2021/12/23 19:28:52] root INFO:     dataset : 
[2021/12/23 19:28:52] root INFO:         data_dir : ./train_data/ic15_data/
[2021/12/23 19:28:52] root INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:28:52] root INFO:         name : SimpleDataSet
[2021/12/23 19:28:52] root INFO:         transforms : 
[2021/12/23 19:28:52] root INFO:             DecodeImage : 
[2021/12/23 19:28:52] root INFO:                 channel_first : False
[2021/12/23 19:28:52] root INFO:                 img_mode : BGR
[2021/12/23 19:28:52] root INFO:             CTCLabelEncode : None
[2021/12/23 19:28:52] root INFO:             RecResizeImg : 
[2021/12/23 19:28:52] root INFO:                 image_shape : [3, 32, 100]
[2021/12/23 19:28:52] root INFO:             KeepKeys : 
[2021/12/23 19:28:52] root INFO:                 keep_keys : ['image', 'label', 'length']
[2021/12/23 19:28:52] root INFO:     loader : 
[2021/12/23 19:28:52] root INFO:         batch_size_per_card : 256
[2021/12/23 19:28:52] root INFO:         drop_last : True
[2021/12/23 19:28:52] root INFO:         num_workers : 8
[2021/12/23 19:28:52] root INFO:         shuffle : True
[2021/12/23 19:28:52] root INFO:         use_shared_memory : False
[2021/12/23 19:28:52] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)
[2021/12/23 19:28:52] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:28:52] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt']
W1223 19:28:52.737390  2821 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W1223 19:28:52.742431  2821 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[2021/12/23 19:28:55] root INFO: loaded pretrained_model successful from rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy.pdparams
[2021/12/23 19:28:55] root INFO: train dataloader has 17 iters
[2021/12/23 19:28:55] root INFO: valid dataloader has 9 iters
[2021/12/23 19:28:55] root INFO: During the training process, after the 0th iteration, an evaluation is run every 200 iterations
[2021/12/23 19:28:55] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:29:00] root INFO: epoch: [1/40], iter: 10, lr: 0.000500, loss: 8.913865, acc: 0.195312, norm_edit_dis: 0.686087, reader_cost: 0.24545 s, batch_cost: 0.41798 s, samples: 2816, ips: 673.71029
[2021/12/23 19:29:01] root INFO: epoch: [1/40], iter: 16, lr: 0.000500, loss: 7.154922, acc: 0.222656, norm_edit_dis: 0.684251, reader_cost: 0.00006 s, batch_cost: 0.06422 s, samples: 1536, ips: 2391.80670
[2021/12/23 19:29:01] root INFO: save model in ./output/rec/ic15/latest
[2021/12/23 19:29:01] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:29:05] root INFO: epoch: [2/40], iter: 20, lr: 0.000500, loss: 6.198568, acc: 0.246094, norm_edit_dis: 0.688872, reader_cost: 0.20438 s, batch_cost: 0.34878 s, samples: 1024, ips: 293.59625
[2021/12/23 19:29:06] root INFO: epoch: [2/40], iter: 30, lr: 0.000500, loss: 4.117401, acc: 0.402344, norm_edit_dis: 0.739680, reader_cost: 0.00016 s, batch_cost: 0.11199 s, samples: 2560, ips: 2285.86780
^C
main proc 2870 exit, kill process group 2821
main proc 2869 exit, kill process group 2821
main proc 2868 exit, kill process group 2821
main proc 2867 exit, kill process group 2821
main proc 2866 exit, kill process group 2821
main proc 2865 exit, kill process group 2821
main proc 2864 exit, kill process group 2821
main proc 2821 exit, kill process group 2821

根据配置文件中设置的 save_model_dir 字段,训练过程中会保存以下几类模型文件:

output/rec/ic15
├── best_accuracy.pdopt  
├── best_accuracy.pdparams  
├── best_accuracy.states  
├── config.yml  
├── iter_epoch_3.pdopt  
├── iter_epoch_3.pdparams  
├── iter_epoch_3.states  
├── latest.pdopt  
├── latest.pdparams  
├── latest.states  
└── train.log

其中 best_accuracy.* 是评估集上的最优模型;iter_epoch_x.* 是以 save_epoch_step 为间隔保存下来的模型;latest.* 是最后一个 epoch 的模型。

总结:

如果需要训练自己的数据需要修改:

  1. 训练和评估数据路径(必须)
  2. 字典路径(必须)
  3. 预训练模型 (可选)
  4. 学习率、image shape、网络结构(可选)
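以 icdar2015 的配置为例,上述需要修改的字段在 yml 中的位置大致如下(路径均为本文示例,训练自己的数据时需替换为实际路径):

```yaml
Global:
  character_dict_path: ppocr/utils/ic15_dict.txt        # 2. 字典路径(必须)
  pretrained_model: ./rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy  # 3. 预训练模型(可选)
Train:
  dataset:
    data_dir: ./train_data/ic15_data/                   # 1. 训练数据路径(必须)
    label_file_list: ['./train_data/ic15_data/rec_gt_train.txt']
Eval:
  dataset:
    data_dir: ./train_data/ic15_data                    # 1. 评估数据路径(必须)
    label_file_list: ['./train_data/ic15_data/rec_gt_test.txt']
```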

4.2 模型评估

评估数据集可以通过修改 configs/rec/rec_icdar15_train.yml 中 Eval.dataset 的 label_file_list 字段来设置。

这里默认使用 icdar2015 的评估集,加载刚刚训练好的模型权重:

!python tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy \
        Global.character_dict_path=ppocr/utils/ic15_dict.txt
[2021/12/23 14:27:51] root INFO: Architecture : 
[2021/12/23 14:27:51] root INFO:     Backbone : 
[2021/12/23 14:27:51] root INFO:         model_name : large
[2021/12/23 14:27:51] root INFO:         name : MobileNetV3
[2021/12/23 14:27:51] root INFO:         scale : 0.5
[2021/12/23 14:27:51] root INFO:     Head : 
[2021/12/23 14:27:51] root INFO:         fc_decay : 0
[2021/12/23 14:27:51] root INFO:         name : CTCHead
[2021/12/23 14:27:51] root INFO:     Neck : 
[2021/12/23 14:27:51] root INFO:         encoder_type : rnn
[2021/12/23 14:27:51] root INFO:         hidden_size : 96
[2021/12/23 14:27:51] root INFO:         name : SequenceEncoder
[2021/12/23 14:27:51] root INFO:     Transform : None
[2021/12/23 14:27:51] root INFO:     algorithm : CRNN
[2021/12/23 14:27:51] root INFO:     model_type : rec
[2021/12/23 14:27:51] root INFO: Eval : 
[2021/12/23 14:27:51] root INFO:     dataset : 
[2021/12/23 14:27:51] root INFO:         data_dir : ./train_data/ic15_data
[2021/12/23 14:27:51] root INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
[2021/12/23 14:27:51] root INFO:         name : SimpleDataSet
[2021/12/23 14:27:51] root INFO:         transforms : 
[2021/12/23 14:27:51] root INFO:             DecodeImage : 
[2021/12/23 14:27:51] root INFO:                 channel_first : False
[2021/12/23 14:27:51] root INFO:                 img_mode : BGR
[2021/12/23 14:27:51] root INFO:             CTCLabelEncode : None
[2021/12/23 14:27:51] root INFO:             RecResizeImg : 
[2021/12/23 14:27:51] root INFO:                 image_shape : [3, 32, 100]
[2021/12/23 14:27:51] root INFO:             KeepKeys : 
[2021/12/23 14:27:51] root INFO:                 keep_keys : ['image', 'label', 'length']
[2021/12/23 14:27:51] root INFO:     loader : 
[2021/12/23 14:27:51] root INFO:         batch_size_per_card : 256
[2021/12/23 14:27:51] root INFO:         drop_last : False
[2021/12/23 14:27:51] root INFO:         num_workers : 4
[2021/12/23 14:27:51] root INFO:         shuffle : False
[2021/12/23 14:27:51] root INFO:         use_shared_memory : False
[2021/12/23 14:27:51] root INFO: Global : 
[2021/12/23 14:27:51] root INFO:     cal_metric_during_train : True
[2021/12/23 14:27:51] root INFO:     character_dict_path : ppocr/utils/ic15_dict.txt
[2021/12/23 14:27:51] root INFO:     character_type : EN
[2021/12/23 14:27:51] root INFO:     checkpoints : output/rec/ic15/best_accuracy
[2021/12/23 14:27:51] root INFO:     debug : False
[2021/12/23 14:27:51] root INFO:     distributed : False
[2021/12/23 14:27:51] root INFO:     epoch_num : 72
[2021/12/23 14:27:51] root INFO:     eval_batch_step : [0, 2000]
[2021/12/23 14:27:51] root INFO:     infer_img : doc/imgs_words_en/word_10.png
[2021/12/23 14:27:51] root INFO:     infer_mode : False
[2021/12/23 14:27:51] root INFO:     log_smooth_window : 20
[2021/12/23 14:27:51] root INFO:     max_text_length : 25
[2021/12/23 14:27:51] root INFO:     pretrained_model : None
[2021/12/23 14:27:51] root INFO:     print_batch_step : 10
[2021/12/23 14:27:51] root INFO:     save_epoch_step : 3
[2021/12/23 14:27:51] root INFO:     save_inference_dir : ./
[2021/12/23 14:27:51] root INFO:     save_model_dir : ./output/rec/ic15/
[2021/12/23 14:27:51] root INFO:     save_res_path : ./output/rec/predicts_ic15.txt
[2021/12/23 14:27:51] root INFO:     use_gpu : True
[2021/12/23 14:27:51] root INFO:     use_space_char : False
[2021/12/23 14:27:51] root INFO:     use_visualdl : False
[2021/12/23 14:27:51] root INFO: Loss : 
[2021/12/23 14:27:51] root INFO:     name : CTCLoss
[2021/12/23 14:27:51] root INFO: Metric : 
[2021/12/23 14:27:51] root INFO:     main_indicator : acc
[2021/12/23 14:27:51] root INFO:     name : RecMetric
[2021/12/23 14:27:51] root INFO: Optimizer : 
[2021/12/23 14:27:51] root INFO:     beta1 : 0.9
[2021/12/23 14:27:51] root INFO:     beta2 : 0.999
[2021/12/23 14:27:51] root INFO:     lr : 
[2021/12/23 14:27:51] root INFO:         learning_rate : 0.0005
[2021/12/23 14:27:51] root INFO:     name : Adam
[2021/12/23 14:27:51] root INFO:     regularizer : 
[2021/12/23 14:27:51] root INFO:         factor : 0
[2021/12/23 14:27:51] root INFO:         name : L2
[2021/12/23 14:27:51] root INFO: PostProcess : 
[2021/12/23 14:27:51] root INFO:     name : CTCLabelDecode
[2021/12/23 14:27:51] root INFO: Train : 
[2021/12/23 14:27:51] root INFO:     dataset : 
[2021/12/23 14:27:51] root INFO:         data_dir : ./train_data/ic15_data/
[2021/12/23 14:27:51] root INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 14:27:51] root INFO:         name : SimpleDataSet
[2021/12/23 14:27:51] root INFO:         transforms : 
[2021/12/23 14:27:51] root INFO:             DecodeImage : 
[2021/12/23 14:27:51] root INFO:                 channel_first : False
[2021/12/23 14:27:51] root INFO:                 img_mode : BGR
[2021/12/23 14:27:51] root INFO:             CTCLabelEncode : None
[2021/12/23 14:27:51] root INFO:             RecResizeImg : 
[2021/12/23 14:27:51] root INFO:                 image_shape : [3, 32, 100]
[2021/12/23 14:27:51] root INFO:             KeepKeys : 
[2021/12/23 14:27:51] root INFO:                 keep_keys : ['image', 'label', 'length']
[2021/12/23 14:27:51] root INFO:     loader : 
[2021/12/23 14:27:51] root INFO:         batch_size_per_card : 256
[2021/12/23 14:27:51] root INFO:         drop_last : True
[2021/12/23 14:27:51] root INFO:         num_workers : 8
[2021/12/23 14:27:51] root INFO:         shuffle : True
[2021/12/23 14:27:51] root INFO:         use_shared_memory : False
[2021/12/23 14:27:51] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)
[2021/12/23 14:27:51] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt']
W1223 14:27:51.861889  5192 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1223 14:27:51.865501  5192 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[2021/12/23 14:27:56] root INFO: resume from output/rec/ic15/best_accuracy
[2021/12/23 14:27:56] root INFO: metric in ckpt ***************
[2021/12/23 14:27:56] root INFO: acc:0.48531535869041886
[2021/12/23 14:27:56] root INFO: norm_edit_dis:0.7895228681338454
[2021/12/23 14:27:56] root INFO: fps:3266.1877400927865
[2021/12/23 14:27:56] root INFO: best_epoch:24
[2021/12/23 14:27:56] root INFO: start_epoch:25
eval model:: 100%|████████████████████████████████| 9/9 [00:02<00:00,  3.32it/s]
[2021/12/23 14:27:59] root INFO: metric eval ***************
[2021/12/23 14:27:59] root INFO: acc:0.48531535869041886
[2021/12/23 14:27:59] root INFO: norm_edit_dis:0.7895228681338454
[2021/12/23 14:27:59] root INFO: fps:4491.015930181665

After evaluation, you can see the trained model's accuracy on the validation set.

PaddleOCR supports alternating training and evaluation. You can modify eval_batch_step in configs/rec/rec_icdar15_train.yml to set the evaluation frequency; by default, evaluation runs every 2000 iterations. During training, the model with the best acc is saved as output/rec/ic15/best_accuracy by default.

If the validation set is large, evaluation will be time-consuming. In that case it is recommended to reduce the evaluation frequency, or to evaluate only after training completes.
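The relevant fields come from the Global section of the config; based on the values in the config dump above, the fragment in rec_icdar15_train.yml looks roughly like this (the exact file layout may differ):

```yaml
Global:
  epoch_num: 72
  # evaluate every 2000 iterations, starting from iteration 0
  eval_batch_step: [0, 2000]
  # write a checkpoint every 3 epochs; the best-acc model is kept separately
  save_epoch_step: 3
  save_model_dir: ./output/rec/ic15/
```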

4.3 Prediction

With a model trained by PaddleOCR, you can run quick prediction using the following script.

Image to predict: doc/imgs_words_en/word_19.png

The prediction image is specified by infer_img; the trained weights are loaded via -o Global.checkpoints:

!python tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy Global.character_dict_path=ppocr/utils/ic15_dict.txt
[2021/12/23 14:29:19] root INFO: Architecture : 
[2021/12/23 14:29:19] root INFO:     Backbone : 
[2021/12/23 14:29:19] root INFO:         model_name : large
[2021/12/23 14:29:19] root INFO:         name : MobileNetV3
[2021/12/23 14:29:19] root INFO:         scale : 0.5
[2021/12/23 14:29:19] root INFO:     Head : 
[2021/12/23 14:29:19] root INFO:         fc_decay : 0
[2021/12/23 14:29:19] root INFO:         name : CTCHead
[2021/12/23 14:29:19] root INFO:     Neck : 
[2021/12/23 14:29:19] root INFO:         encoder_type : rnn
[2021/12/23 14:29:19] root INFO:         hidden_size : 96
[2021/12/23 14:29:19] root INFO:         name : SequenceEncoder
[2021/12/23 14:29:19] root INFO:     Transform : None
[2021/12/23 14:29:19] root INFO:     algorithm : CRNN
[2021/12/23 14:29:19] root INFO:     model_type : rec
[2021/12/23 14:29:19] root INFO: Eval : 
[2021/12/23 14:29:19] root INFO:     dataset : 
[2021/12/23 14:29:19] root INFO:         data_dir : ./train_data/ic15_data
[2021/12/23 14:29:19] root INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
[2021/12/23 14:29:19] root INFO:         name : SimpleDataSet
[2021/12/23 14:29:19] root INFO:         transforms : 
[2021/12/23 14:29:19] root INFO:             DecodeImage : 
[2021/12/23 14:29:19] root INFO:                 channel_first : False
[2021/12/23 14:29:19] root INFO:                 img_mode : BGR
[2021/12/23 14:29:19] root INFO:             CTCLabelEncode : None
[2021/12/23 14:29:19] root INFO:             RecResizeImg : 
[2021/12/23 14:29:19] root INFO:                 image_shape : [3, 32, 100]
[2021/12/23 14:29:19] root INFO:             KeepKeys : 
[2021/12/23 14:29:19] root INFO:                 keep_keys : ['image', 'label', 'length']
[2021/12/23 14:29:19] root INFO:     loader : 
[2021/12/23 14:29:19] root INFO:         batch_size_per_card : 256
[2021/12/23 14:29:19] root INFO:         drop_last : False
[2021/12/23 14:29:19] root INFO:         num_workers : 4
[2021/12/23 14:29:19] root INFO:         shuffle : False
[2021/12/23 14:29:19] root INFO:         use_shared_memory : False
[2021/12/23 14:29:19] root INFO: Global : 
[2021/12/23 14:29:19] root INFO:     cal_metric_during_train : True
[2021/12/23 14:29:19] root INFO:     character_dict_path : ppocr/utils/ic15_dict.txt
[2021/12/23 14:29:19] root INFO:     character_type : EN
[2021/12/23 14:29:19] root INFO:     checkpoints : output/rec/ic15/best_accuracy
[2021/12/23 14:29:19] root INFO:     debug : False
[2021/12/23 14:29:19] root INFO:     distributed : False
[2021/12/23 14:29:19] root INFO:     epoch_num : 72
[2021/12/23 14:29:19] root INFO:     eval_batch_step : [0, 2000]
[2021/12/23 14:29:19] root INFO:     infer_img : doc/imgs_words_en/word_19.png
[2021/12/23 14:29:19] root INFO:     infer_mode : False
[2021/12/23 14:29:19] root INFO:     log_smooth_window : 20
[2021/12/23 14:29:19] root INFO:     max_text_length : 25
[2021/12/23 14:29:19] root INFO:     pretrained_model : None
[2021/12/23 14:29:19] root INFO:     print_batch_step : 10
[2021/12/23 14:29:19] root INFO:     save_epoch_step : 3
[2021/12/23 14:29:19] root INFO:     save_inference_dir : ./
[2021/12/23 14:29:19] root INFO:     save_model_dir : ./output/rec/ic15/
[2021/12/23 14:29:19] root INFO:     save_res_path : ./output/rec/predicts_ic15.txt
[2021/12/23 14:29:19] root INFO:     use_gpu : True
[2021/12/23 14:29:19] root INFO:     use_space_char : False
[2021/12/23 14:29:19] root INFO:     use_visualdl : False
[2021/12/23 14:29:19] root INFO: Loss : 
[2021/12/23 14:29:19] root INFO:     name : CTCLoss
[2021/12/23 14:29:19] root INFO: Metric : 
[2021/12/23 14:29:19] root INFO:     main_indicator : acc
[2021/12/23 14:29:19] root INFO:     name : RecMetric
[2021/12/23 14:29:19] root INFO: Optimizer : 
[2021/12/23 14:29:19] root INFO:     beta1 : 0.9
[2021/12/23 14:29:19] root INFO:     beta2 : 0.999
[2021/12/23 14:29:19] root INFO:     lr : 
[2021/12/23 14:29:19] root INFO:         learning_rate : 0.0005
[2021/12/23 14:29:19] root INFO:     name : Adam
[2021/12/23 14:29:19] root INFO:     regularizer : 
[2021/12/23 14:29:19] root INFO:         factor : 0
[2021/12/23 14:29:19] root INFO:         name : L2
[2021/12/23 14:29:19] root INFO: PostProcess : 
[2021/12/23 14:29:19] root INFO:     name : CTCLabelDecode
[2021/12/23 14:29:19] root INFO: Train : 
[2021/12/23 14:29:19] root INFO:     dataset : 
[2021/12/23 14:29:19] root INFO:         data_dir : ./train_data/ic15_data/
[2021/12/23 14:29:19] root INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 14:29:19] root INFO:         name : SimpleDataSet
[2021/12/23 14:29:19] root INFO:         transforms : 
[2021/12/23 14:29:19] root INFO:             DecodeImage : 
[2021/12/23 14:29:19] root INFO:                 channel_first : False
[2021/12/23 14:29:19] root INFO:                 img_mode : BGR
[2021/12/23 14:29:19] root INFO:             CTCLabelEncode : None
[2021/12/23 14:29:19] root INFO:             RecResizeImg : 
[2021/12/23 14:29:19] root INFO:                 image_shape : [3, 32, 100]
[2021/12/23 14:29:19] root INFO:             KeepKeys : 
[2021/12/23 14:29:19] root INFO:                 keep_keys : ['image', 'label', 'length']
[2021/12/23 14:29:19] root INFO:     loader : 
[2021/12/23 14:29:19] root INFO:         batch_size_per_card : 256
[2021/12/23 14:29:19] root INFO:         drop_last : True
[2021/12/23 14:29:19] root INFO:         num_workers : 8
[2021/12/23 14:29:19] root INFO:         shuffle : True
[2021/12/23 14:29:19] root INFO:         use_shared_memory : False
[2021/12/23 14:29:19] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)
W1223 14:29:19.803710  5290 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1223 14:29:19.807695  5290 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[2021/12/23 14:29:25] root INFO: resume from output/rec/ic15/best_accuracy
[2021/12/23 14:29:25] root INFO: infer_img: doc/imgs_words_en/word_19.png
pred idx: Tensor(shape=[1, 25], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
       [[29, 0 , 0 , 0 , 22, 0 , 0 , 0 , 25, 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 33]])
[2021/12/23 14:29:25] root INFO:      result: slow    0.8795223
[2021/12/23 14:29:25] root INFO: success!

This yields the prediction result for the input image:

infer_img: doc/imgs_words_en/word_19.png
        result: slow    0.8795223
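The `pred idx` tensor printed above contains the per-timestep argmax over the character classes; CTC decoding then collapses consecutive repeats and removes the blank token to produce `slow`. A minimal sketch of that greedy decoding, assuming (as in ic15_dict.txt) a 36-character dictionary of digits followed by lowercase letters, with the CTC blank at index 0:

```python
# Minimal CTC greedy decoding sketch (assumption: blank token at index 0,
# so character indices are offset by 1 into the dictionary, which here is
# '0'-'9' followed by 'a'-'z' as in ic15_dict.txt).
import string

CHARSET = list(string.digits + string.ascii_lowercase)  # 36 characters

def ctc_greedy_decode(indices, blank=0):
    """Collapse consecutive repeats, drop blanks, then map to characters."""
    chars = []
    prev = blank
    for idx in indices:
        if idx != blank and idx != prev:
            chars.append(CHARSET[idx - 1])  # -1 because index 0 is the blank
        prev = idx
    return "".join(chars)

# The argmax indices printed by infer_rec.py for word_19.png:
pred_idx = [29, 0, 0, 0, 22, 0, 0, 0, 25, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33]
print(ctc_greedy_decode(pred_idx))  # prints "slow"
```

The repeated zeros between characters are blank frames: the narrow text occupies only a few of the 25 output timesteps, and CTC lets the model emit blanks everywhere else.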

Exercises

[Question 1]

Visualize the results of the noise and jitter data augmentations implemented in PaddleOCR, and describe their effects in words.

Optional test images:
https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_1.jpg
https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_2.jpg
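For reference, the two augmentations can be approximated with plain NumPy as below. This is an illustrative sketch (zero-mean Gaussian pixel noise, and a small random diagonal shift for jitter), not PaddleOCR's exact implementation:

```python
# Illustrative sketch of noise and jitter augmentations using only NumPy.
# Assumption: images are float arrays in [0, 1], shape (H, W, C).
import numpy as np

def add_gauss_noise(img, mean=0.0, var=0.1):
    """Add zero-mean Gaussian noise and clip back into [0, 1]."""
    noise = np.random.normal(mean, var ** 0.5, img.shape)
    return np.clip(img + noise, 0.0, 1.0)

def jitter(img, max_ratio=0.01):
    """Shift the image a few pixels along the diagonal, duplicating edges."""
    h, w = img.shape[:2]
    s = int(min(h, w) * np.random.rand() * max_ratio)
    out = img.copy()
    for i in range(s):
        out[i:, i:] = img[:h - i, :w - i]
    return out

img = np.random.rand(32, 100, 3)  # dummy H x W x C image
noisy = add_gauss_noise(img)
jittered = jitter(img)
assert noisy.shape == img.shape and jittered.shape == img.shape
```

Intuitively, noise simulates sensor grain and compression artifacts in low-quality captures, while jitter simulates slight misalignment or smearing of the cropped text line; both make the recognizer more robust to imperfect inputs.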

[Question 2]

Replace the backbone in the configs/rec/rec_icdar15_train.yml config with ResNet34_vd from PaddleOCR. When the input image shape is (3, 32, 100), what is the size of the feature finally output by the Head layer?

[Question 3]

Download the 100k-sample Chinese dataset rec_data_lesson_demo, modify the configs/rec/rec_icdar15_train.yml config file to train a recognition model, and provide the training log.

A pretrained model can be loaded from: https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar

Summary

At this point, a complete CRNN-based text recognition task is finished. For more features and code, refer to PaddleOCR.

If you have any questions or doubts about this project, feel free to raise them in the comments section.