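This section describes how to train and run the DB text detection algorithm with PaddleOCR. It covers:

  1. Trying out text detection quickly with the paddleocr package
  2. Understanding the principles of the DB text detection algorithm
  3. Learning how a text detection model is built
  4. Learning how a text detection model is trained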

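1. Quick Start

This section uses paddleocr as an example to show how to run text detection in three steps:

  1. Install paddleocr
  2. Run the DB algorithm with a single command to get detection results
  3. Visualize the text detection results

Install the paddleocr whl package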

    !pip install --upgrade pip
    !pip install paddleocr
    Successfully installed PyWavelets-1.2.0 cssselect-1.1.0 cssutils-2.3.0 fasttext-0.9.1 imgaug-0.4.0 lmdb-1.2.1 lxml-4.7.1 opencv-contrib-python-4.4.0.46 paddleocr-2.3.0.2 premailer-3.10.0 pybind11-2.8.1 pyclipper-1.3.0.post2 python-Levenshtein-0.12.2 scikit-image-0.19.1 shapely-1.8.0 tifffile-2021.11.2
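
To confirm that the installation succeeded before moving on, a small check like the one below may help. It is only a sketch: `installed_version` is a hypothetical helper name, not part of paddleocr, and it relies on the package metadata written by pip.

```python
try:
    from importlib import metadata  # standard library on Python 3.8+
except ImportError:
    import importlib_metadata as metadata  # backport package for Python 3.7


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints the version if paddleocr is installed, otherwise None.
print("paddleocr:", installed_version("paddleocr"))
```

If the result is `None`, the install step above most likely did not complete in the current environment.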

Run text detection with a single command

On the first run, paddleocr automatically downloads the PP-OCRv2 lightweight models and uses them.
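
To see which models have already been downloaded, one option is to list the cache directory. The sketch below assumes the layout that appears in the download log of paddleocr 2.3.0.2 in this tutorial (model directories ending in `_infer` under `~/.paddleocr`). Other versions may store models elsewhere. `list_cached_models` is a hypothetical helper, not a paddleocr function.

```python
import os
from pathlib import Path


def list_cached_models(cache_root):
    """Return sorted relative paths of directories whose name ends with '_infer'."""
    root = Path(os.path.expanduser(str(cache_root)))
    if not root.is_dir():
        return []
    return sorted(
        p.relative_to(root).as_posix()
        for p in root.rglob("*_infer")
        if p.is_dir()
    )


# Assumed default cache location, taken from the download log; prints nothing
# if no model has been downloaded yet.
for name in list_cached_models("~/.paddleocr"):
    print(name)
```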

With paddleocr installed, using ./doc/imgs/12.jpg as the input image gives the following prediction:

Figure: 12.jpg

    [[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]]
    [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]]
    [[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]]
    [[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]]

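The prediction contains four text boxes. Each line holds the four corner points of one text box as [x, y] pixel coordinates, listed clockwise starting from the top-left corner.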
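
The point ordering described above can be checked with a few lines of plain Python. The sketch below copies the four boxes from the prediction, verifies the clockwise order with the shoelace formula, and computes each box's area and axis-aligned bounding rectangle. The helper names (`shoelace_sum`, `is_clockwise_on_image`, `polygon_area`, `bounding_rect`) are hypothetical and not part of paddleocr.

```python
# Boxes copied from the prediction above; each point is [x, y] in pixels.
boxes = [
    [[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]],
    [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]],
    [[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]],
    [[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]],
]


def shoelace_sum(points):
    """Return twice the signed polygon area (shoelace formula)."""
    total = 0.0
    n = len(points)
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        total += x1 * y2 - x2 * y1
    return total


def is_clockwise_on_image(points):
    """In image coordinates (y grows downward), a clockwise polygon has a positive sum."""
    return shoelace_sum(points) > 0


def polygon_area(points):
    """Return the unsigned polygon area in square pixels."""
    return abs(shoelace_sum(points)) / 2.0


def bounding_rect(points):
    """Return (x_min, y_min, x_max, y_max) of the axis-aligned bounding rectangle."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return min(xs), min(ys), max(xs), max(ys)


for box in boxes:
    print(is_clockwise_on_image(box), polygon_area(box), bounding_rect(box))
```

The bounding rectangle is always at least as large as the detected quadrilateral, because the detected text lines are slightly tilted.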

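The paddleocr command-line call that runs the text detection model on the image ./doc/imgs/12.jpg is:

    # --image_dir is the path of the image to predict; --rec false skips text recognition and runs text detection only
    ! paddleocr --image_dir ./PaddleOCR/doc/imgs/12.jpg --rec false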
    [2021/12/22 14:34:44] root WARNING: version PP-OCRv2 not support cls models, auto switch to version PP-OCR
    download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar to /home/aistudio/.paddleocr/2.3.0.2/ocr/det/ch/ch_PP-OCRv2_det_infer/ch_PP-OCRv2_det_infer.tar
    100%|█████████████████████████████████████| 3.19M/3.19M [00:00<00:00, 43.1MiB/s]
    download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar to /home/aistudio/.paddleocr/2.3.0.2/ocr/rec/ch/ch_PP-OCRv2_rec_infer/ch_PP-OCRv2_rec_infer.tar
    100%|█████████████████████████████████████| 8.88M/8.88M [00:00<00:00, 50.3MiB/s]
    download https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar to /home/aistudio/.paddleocr/2.3.0.2/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar
    100%|█████████████████████████████████████| 1.45M/1.45M [00:00<00:00, 33.8MiB/s]
    Namespace(benchmark=False, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, det=True, det_algorithm='DB', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/det/ch/ch_PP-OCRv2_det_infer', det_pse_box_thresh=0.85, det_pse_box_type='box', det_pse_min_area=16, det_pse_scale=1, det_pse_thresh=0, det_sast_nms_thresh=0.2, det_sast_polygon=False, det_sast_score_thresh=0.5, drop_score=0.5, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_polygon=True, e2e_pgnet_score_thresh=0.5, e2e_pgnet_valid_set='totaltext', enable_mkldnn=False, gpu_mem=500, help='==SUPPRESS==', image_dir='./PaddleOCR/doc/imgs/12.jpg', ir_optim=True, label_list=['0', '180'], lang='ch', layout_path_model='lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config', max_batch_size=10, max_text_length=25, min_subgraph_size=15, ocr_version='PP-OCRv2', output='./output/table', precision='fp32', process_id=0, rec=False, rec_algorithm='CRNN', rec_batch_num=6, rec_char_dict_path='/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt', rec_image_shape='3, 32, 320', rec_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/rec/ch/ch_PP-OCRv2_rec_infer', save_log_path='./log_output/', show_log=True, structure_version='STRUCTURE', table_char_dict_path=None, table_char_type='en', table_max_len=488, table_model_dir=None, total_process_num=1, type='ocr', use_angle_cls=False, use_dilation=False, use_gpu=True, use_mp=False, use_onnx=False, use_pdserving=False, use_space_char=True, use_tensorrt=False, vis_font_path='./doc/fonts/simfang.ttf', warmup=True)
    [2021/12/22 14:34:47] root INFO: **********./PaddleOCR/doc/imgs/12.jpg**********
    [2021/12/22 14:34:49] root INFO: [[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]]
    [2021/12/22 14:34:49] root INFO: [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]]
    [2021/12/22 14:34:49] root INFO: [[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]]
    [2021/12/22 14:34:49] root INFO: [[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]]
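
If the command-line output has been captured as text, the boxes can be recovered from the `root INFO` lines. The following is a minimal sketch that assumes the log format shown above, which may differ in other paddleocr versions. `parse_boxes` and `BOX_LINE` are hypothetical names introduced for illustration.

```python
import ast
import re

# Matches a log line whose message is a nested list such as [[x, y], ...].
BOX_LINE = re.compile(r"root INFO: (\[\[.*\]\])\s*$")


def parse_boxes(log_text):
    """Extract text boxes (lists of [x, y] points) from captured log text."""
    boxes = []
    for line in log_text.splitlines():
        match = BOX_LINE.search(line)
        if match:
            # literal_eval safely parses the list without executing code.
            boxes.append(ast.literal_eval(match.group(1)))
    return boxes


# Two result lines and the image banner, copied from the output above.
sample_log = """\
[2021/12/22 14:34:47] root INFO: **********./PaddleOCR/doc/imgs/12.jpg**********
[2021/12/22 14:34:49] root INFO: [[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]]
[2021/12/22 14:34:49] root INFO: [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]]
"""

parsed = parse_boxes(sample_log)
print(len(parsed), "boxes parsed; first point:", parsed[0][0])
```

Lines that do not carry a box, such as the image banner or the download messages, are skipped.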
In addition to the command line, paddleocr can also be called from Python code, as follows:
    # On the first run, uncomment the next line to download the PaddleOCR code
    #!git clone https://gitee.com/paddlepaddle/PaddleOCR
    import os
    # Change the working directory to /home/aistudio/PaddleOCR
    os.chdir("/home/aistudio/PaddleOCR")
    # Install PaddleOCR's third-party dependencies
    !pip install --upgrade pip
    !pip install -r requirements.txt
    # 1. Import the PaddleOCR class from paddleocr
    from paddleocr import PaddleOCR
    import numpy as np
    import cv2
    import matplotlib.pyplot as plt
    # Required to display matplotlib.pyplot figures inside a notebook
    %matplotlib inline

    # 2. Instantiate PaddleOCR
    ocr = PaddleOCR()
    img_path = './PaddleOCR/doc/imgs/12.jpg'
    # 3. Run prediction; rec=False skips recognition and returns only the detected boxes
    result = ocr.ocr(img_path, rec=False)
    print(f"The predicted text box of {img_path} are follows.")
    print(result)

  1. [2021/12/22 14:35:01] root WARNING: version PP-OCRv2 not support cls models, auto switch to version PP-OCR
  2. Namespace(benchmark=False, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, det=True, det_algorithm='DB', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/det/ch/ch_PP-OCRv2_det_infer', det_pse_box_thresh=0.85, det_pse_box_type='box', det_pse_min_area=16, det_pse_scale=1, det_pse_thresh=0, det_sast_nms_thresh=0.2, det_sast_polygon=False, det_sast_score_thresh=0.5, drop_score=0.5, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_polygon=True, e2e_pgnet_score_thresh=0.5, e2e_pgnet_valid_set='totaltext', enable_mkldnn=False, gpu_mem=500, help='==SUPPRESS==', image_dir=None, ir_optim=True, label_list=['0', '180'], lang='ch', layout_path_model='lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config', max_batch_size=10, max_text_length=25, min_subgraph_size=15, ocr_version='PP-OCRv2', output='./output/table', precision='fp32', process_id=0, rec=True, rec_algorithm='CRNN', rec_batch_num=6, rec_char_dict_path='/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt', rec_image_shape='3, 32, 320', rec_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/rec/ch/ch_PP-OCRv2_rec_infer', save_log_path='./log_output/', show_log=True, structure_version='STRUCTURE', table_char_dict_path=None, table_char_type='en', table_max_len=488, table_model_dir=None, total_process_num=1, type='ocr', use_angle_cls=False, use_dilation=False, use_gpu=True, use_mp=False, use_onnx=False, use_pdserving=False, use_space_char=True, use_tensorrt=False, vis_font_path='./doc/fonts/simfang.ttf', 
warmup=True)
  3. [2021/12/22 14:35:03] root WARNING: Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process
  4. The predicted text box of ./PaddleOCR/doc/imgs/12.jpg are follows.
  5. [[[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]], [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]], [[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]], [[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]]]

Visualize the text detection results

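    # 4. Visualize the detection results
    image = cv2.imread(img_path)
    # cv2.imread returns BGR; convert to RGB so matplotlib shows the true colors
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # With rec=False, each element of result is one box given by its four corner points
    for box in result:
        # cv2.polylines expects integer points; int32 is accepted on every platform
        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int32)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    # Show the image with the boxes drawn on it
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
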
  1. <matplotlib.image.AxesImage at 0x7f23d478d0d0>

[Figure: the input image with the detected text boxes drawn on it]

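2. Detailed implementation of the DB text detection algorithm

2.1 Principle of the DB text detection algorithm

DB is a segmentation-based text detection algorithm. It proposes a Differentiable Binarization module (DB module) that uses a dynamic threshold to separate text regions from the background.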

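Figure 1: Differences between the DB model and other methods

Conventional segmentation-based text detectors follow the blue arrows in the figure above: after obtaining the segmentation result, they apply a fixed threshold to get a binarized segmentation map, and then use heuristics such as pixel clustering to obtain the text regions.

DB follows the red arrows in the figure. The main difference is that DB has a threshold map: the network predicts a threshold for every position in the image instead of using one fixed value, which separates the text foreground from the background more cleanly.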

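DB has the following advantages:

  1. The algorithm structure is simple and needs no complicated post-processing
  2. It achieves good accuracy and performance on open-source datasets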

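In traditional image segmentation algorithms, once the probability map is obtained it is processed with standard binarization: pixels below a fixed threshold t are set to 0 and the remaining pixels are set to 1:

$$B_{i,j} = \begin{cases} 1, & P_{i,j} \ge t \\ 0, & \text{otherwise} \end{cases}$$

However, standard binarization is not differentiable, so the network cannot be trained end to end with it. To solve this problem, DB proposes Differentiable Binarization (DB), which approximates the step function of standard binarization with:

$$\hat{B}_{i,j} = \frac{1}{1 + e^{-k(P_{i,j}-T_{i,j})}}$$

Here P is the probability map and T is the threshold map described above, and k is the amplifying factor, set empirically to 50 in the experiments. Figure 3(a) compares standard binarization with differentiable binarization.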
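The two binarization rules can be compared numerically. The following is a minimal NumPy sketch, not PaddleOCR code: the function names, the 1x4 probability map, and the fixed threshold of 0.3 are illustrative assumptions, while k = 50 follows the text.

```python
import numpy as np

def standard_binarize(prob, t=0.3):
    """Hard threshold: 1 where prob >= t, else 0 (not differentiable)."""
    return (prob >= t).astype(np.float32)

def differentiable_binarize(prob, thresh, k=50.0):
    """Sigmoid approximation of the step function with amplifying factor k."""
    return 1.0 / (1.0 + np.exp(-k * (prob - thresh)))

# Hypothetical 1x4 probability map and a constant threshold map
P = np.array([[0.10, 0.29, 0.31, 0.90]], dtype=np.float32)
T = np.full_like(P, 0.3)

hard = standard_binarize(P, t=0.3)
soft = differentiable_binarize(P, T, k=50.0)
print(hard)
print(np.round(soft, 3))
```

Away from the threshold the two outputs nearly coincide. Close to it the soft version moves smoothly between 0 and 1, which is what lets gradients flow during training.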

With a cross-entropy loss, the losses for positive and negative samples are $l_+$ and $l_-$ respectively:

$$l_+ = -\log\left(\frac{1}{1 + e^{-k(P_{i,j}-T_{i,j})}}\right)$$

$$l_- = -\log\left(1-\frac{1}{1 + e^{-k(P_{i,j}-T_{i,j})}}\right)$$

Writing $x = P_{i,j} - T_{i,j}$ and $f(x) = \frac{1}{1 + e^{-kx}}$, the partial derivatives with respect to the input $x$ are:

$$\frac{\partial l_+}{\partial x} = -k f(x) e^{-kx}$$

$$\frac{\partial l_-}{\partial x} = k f(x)$$

Both gradients are scaled by the amplifying factor k, which enlarges the gradient of wrong predictions and so helps the model optimize toward better results. In Figure 3(b), the region $x < 0$ is where positive samples are predicted as negative, and the amplifying factor k enlarges the gradient there; in Figure 3(c), the region $x > 0$ is where negative samples are predicted as positive, and the gradient is enlarged in the same way.

[Figure 3: (a) standard binarization versus differentiable binarization; (b) derivative of l_+; (c) derivative of l_-]
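The closed-form gradients, -k f(x) e^{-kx} for the positive loss and k f(x) for the negative loss, can be checked against a finite-difference estimate. This is a small NumPy sketch; the helper names and sample points are illustrative assumptions.

```python
import numpy as np

K = 50.0  # amplifying factor k from the text

def f(x, k=K):
    """Differentiable binarization as a function of x = P - T."""
    return 1.0 / (1.0 + np.exp(-k * x))

def loss_pos(x, k=K):
    """Cross-entropy loss of a positive sample."""
    return -np.log(f(x, k))

def loss_neg(x, k=K):
    """Cross-entropy loss of a negative sample."""
    return -np.log(1.0 - f(x, k))

def grad_pos(x, k=K):
    """Closed-form derivative of loss_pos with respect to x."""
    return -k * f(x, k) * np.exp(-k * x)

def grad_neg(x, k=K):
    """Closed-form derivative of loss_neg with respect to x."""
    return k * f(x, k)

def numeric_grad(fn, x, eps=1e-6):
    """Central finite-difference estimate of the derivative of fn at x."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

xs = [-0.05, -0.01, 0.0, 0.01, 0.05]
for x in xs:
    print(x, grad_pos(x), numeric_grad(loss_pos, x), grad_neg(x), numeric_grad(loss_neg, x))
```

The magnitude of the positive-sample gradient is largest for x < 0 and that of the negative-sample gradient is largest for x > 0, which are exactly the wrongly predicted regions.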

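The overall structure of the DB algorithm is shown in the figure below:

[Figure: overall structure of the DB algorithm]

The input image passes through the Backbone and the FPN to extract features. The extracted features are concatenated into a feature map one quarter the size of the original image. Convolution layers then produce the text-region probability map and the threshold map, and DB post-processing turns them into the text bounding curves.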

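2.2 Building the DB text detection model

The DB text detection model can be divided into three parts:

  • Backbone network, which extracts the image features
  • FPN network, a feature pyramid structure that enhances the features
  • Head network, which computes the text-region probability map

This section implements these three network modules with PaddlePaddle and then assembles the complete network.

Backbone network

The Backbone of the DB text detection network is an image classification network. The paper uses ResNet50; to speed up training, the experiments in this section use the MobileNetV3 large structure as the backbone.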

    # On the first run, uncomment the next line to download the PaddleOCR code
    #!git clone https://gitee.com/paddlepaddle/PaddleOCR
    import os
    # Change the default working directory to /home/aistudio/PaddleOCR
    os.chdir("/home/aistudio/PaddleOCR")
    # Install PaddleOCR's third-party dependencies
    !pip install --upgrade pip
    !pip install -r requirements.txt

  1. Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
  2. Requirement already satisfied: pip in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (21.3.1)
    # https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/backbones/det_mobilenet_v3.py
    from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3

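If you want to train with ResNet as the Backbone, you can choose ResNet from the PaddleOCR code, or pick a backbone model from PaddleClas.

The DB Backbone extracts multi-scale features from the image. As the following code shows, for an input of shape [640, 640] the backbone outputs four features, with shapes [1, 16, 160, 160], [1, 24, 80, 80], [1, 56, 40, 40], and [1, 480, 20, 20]. These features are fed to the feature pyramid network (FPN), which enhances them further.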

    import paddle

    fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
    # 1. Instantiate the Backbone
    model_backbone = MobileNetV3()
    model_backbone.eval()
    # 2. Run a forward pass
    outs = model_backbone(fake_inputs)
    # 3. Print the network structure
    print(model_backbone)
    # 4. Print the shapes of the output features
    for idx, out in enumerate(outs):
        print("The index is ", idx, "and the shape of output is ", out.shape)

  1. W1222 14:40:35.323043 565 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
  2. W1222 14:40:35.328037 565 device_context.cc:465] device: 0, cuDNN Version: 7.6.
  3. MobileNetV3(
  4. (conv): ConvBNLayer(
  5. (conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
  6. (bn): BatchNorm()
  7. )
  8. (stage0): Sequential(
  9. (0): ResidualUnit(
  10. (expand_conv): ConvBNLayer(
  11. (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)
  12. (bn): BatchNorm()
  13. )
  14. (bottleneck_conv): ConvBNLayer(
  15. (conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)
  16. (bn): BatchNorm()
  17. )
  18. (linear_conv): ConvBNLayer(
  19. (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)
  20. (bn): BatchNorm()
  21. )
  22. )
  23. (1): ResidualUnit(
  24. (expand_conv): ConvBNLayer(
  25. (conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)
  26. (bn): BatchNorm()
  27. )
  28. (bottleneck_conv): ConvBNLayer(
  29. (conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=32, data_format=NCHW)
  30. (bn): BatchNorm()
  31. )
  32. (linear_conv): ConvBNLayer(
  33. (conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)
  34. (bn): BatchNorm()
  35. )
  36. )
  37. (2): ResidualUnit(
  38. (expand_conv): ConvBNLayer(
  39. (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)
  40. (bn): BatchNorm()
  41. )
  42. (bottleneck_conv): ConvBNLayer(
  43. (conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)
  44. (bn): BatchNorm()
  45. )
  46. (linear_conv): ConvBNLayer(
  47. (conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)
  48. (bn): BatchNorm()
  49. )
  50. )
  51. )
  52. (stage1): Sequential(
  53. (0): ResidualUnit(
  54. (expand_conv): ConvBNLayer(
  55. (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)
  56. (bn): BatchNorm()
  57. )
  58. (bottleneck_conv): ConvBNLayer(
  59. (conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=40, data_format=NCHW)
  60. (bn): BatchNorm()
  61. )
  62. (mid_se): SEModule(
  63. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  64. (conv1): Conv2D(40, 10, kernel_size=[1, 1], data_format=NCHW)
  65. (conv2): Conv2D(10, 40, kernel_size=[1, 1], data_format=NCHW)
  66. )
  67. (linear_conv): ConvBNLayer(
  68. (conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)
  69. (bn): BatchNorm()
  70. )
  71. )
  72. (1): ResidualUnit(
  73. (expand_conv): ConvBNLayer(
  74. (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)
  75. (bn): BatchNorm()
  76. )
  77. (bottleneck_conv): ConvBNLayer(
  78. (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)
  79. (bn): BatchNorm()
  80. )
  81. (mid_se): SEModule(
  82. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  83. (conv1): Conv2D(64, 16, kernel_size=[1, 1], data_format=NCHW)
  84. (conv2): Conv2D(16, 64, kernel_size=[1, 1], data_format=NCHW)
  85. )
  86. (linear_conv): ConvBNLayer(
  87. (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)
  88. (bn): BatchNorm()
  89. )
  90. )
  91. (2): ResidualUnit(
  92. (expand_conv): ConvBNLayer(
  93. (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)
  94. (bn): BatchNorm()
  95. )
  96. (bottleneck_conv): ConvBNLayer(
  97. (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)
  98. (bn): BatchNorm()
  99. )
  100. (mid_se): SEModule(
  101. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  102. (conv1): Conv2D(64, 16, kernel_size=[1, 1], data_format=NCHW)
  103. (conv2): Conv2D(16, 64, kernel_size=[1, 1], data_format=NCHW)
  104. )
  105. (linear_conv): ConvBNLayer(
  106. (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)
  107. (bn): BatchNorm()
  108. )
  109. )
  110. )
  111. (stage2): Sequential(
  112. (0): ResidualUnit(
  113. (expand_conv): ConvBNLayer(
  114. (conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)
  115. (bn): BatchNorm()
  116. )
  117. (bottleneck_conv): ConvBNLayer(
  118. (conv): Conv2D(120, 120, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=120, data_format=NCHW)
  119. (bn): BatchNorm()
  120. )
  121. (linear_conv): ConvBNLayer(
  122. (conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)
  123. (bn): BatchNorm()
  124. )
  125. )
  126. (1): ResidualUnit(
  127. (expand_conv): ConvBNLayer(
  128. (conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)
  129. (bn): BatchNorm()
  130. )
  131. (bottleneck_conv): ConvBNLayer(
  132. (conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)
  133. (bn): BatchNorm()
  134. )
  135. (linear_conv): ConvBNLayer(
  136. (conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)
  137. (bn): BatchNorm()
  138. )
  139. )
  140. (2): ResidualUnit(
  141. (expand_conv): ConvBNLayer(
  142. (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)
  143. (bn): BatchNorm()
  144. )
  145. (bottleneck_conv): ConvBNLayer(
  146. (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)
  147. (bn): BatchNorm()
  148. )
  149. (linear_conv): ConvBNLayer(
  150. (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)
  151. (bn): BatchNorm()
  152. )
  153. )
  154. (3): ResidualUnit(
  155. (expand_conv): ConvBNLayer(
  156. (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)
  157. (bn): BatchNorm()
  158. )
  159. (bottleneck_conv): ConvBNLayer(
  160. (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)
  161. (bn): BatchNorm()
  162. )
  163. (linear_conv): ConvBNLayer(
  164. (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)
  165. (bn): BatchNorm()
  166. )
  167. )
  168. (4): ResidualUnit(
  169. (expand_conv): ConvBNLayer(
  170. (conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)
  171. (bn): BatchNorm()
  172. )
  173. (bottleneck_conv): ConvBNLayer(
  174. (conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)
  175. (bn): BatchNorm()
  176. )
  177. (mid_se): SEModule(
  178. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  179. (conv1): Conv2D(240, 60, kernel_size=[1, 1], data_format=NCHW)
  180. (conv2): Conv2D(60, 240, kernel_size=[1, 1], data_format=NCHW)
  181. )
  182. (linear_conv): ConvBNLayer(
  183. (conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)
  184. (bn): BatchNorm()
  185. )
  186. )
  187. (5): ResidualUnit(
  188. (expand_conv): ConvBNLayer(
  189. (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)
  190. (bn): BatchNorm()
  191. )
  192. (bottleneck_conv): ConvBNLayer(
  193. (conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)
  194. (bn): BatchNorm()
  195. )
  196. (mid_se): SEModule(
  197. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  198. (conv1): Conv2D(336, 84, kernel_size=[1, 1], data_format=NCHW)
  199. (conv2): Conv2D(84, 336, kernel_size=[1, 1], data_format=NCHW)
  200. )
  201. (linear_conv): ConvBNLayer(
  202. (conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)
  203. (bn): BatchNorm()
  204. )
  205. )
  206. )
  207. (stage3): Sequential(
  208. (0): ResidualUnit(
  209. (expand_conv): ConvBNLayer(
  210. (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)
  211. (bn): BatchNorm()
  212. )
  213. (bottleneck_conv): ConvBNLayer(
  214. (conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=336, data_format=NCHW)
  215. (bn): BatchNorm()
  216. )
  217. (mid_se): SEModule(
  218. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  219. (conv1): Conv2D(336, 84, kernel_size=[1, 1], data_format=NCHW)
  220. (conv2): Conv2D(84, 336, kernel_size=[1, 1], data_format=NCHW)
  221. )
  222. (linear_conv): ConvBNLayer(
  223. (conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)
  224. (bn): BatchNorm()
  225. )
  226. )
  227. (1): ResidualUnit(
  228. (expand_conv): ConvBNLayer(
  229. (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
  230. (bn): BatchNorm()
  231. )
  232. (bottleneck_conv): ConvBNLayer(
  233. (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)
  234. (bn): BatchNorm()
  235. )
  236. (mid_se): SEModule(
  237. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  238. (conv1): Conv2D(480, 120, kernel_size=[1, 1], data_format=NCHW)
  239. (conv2): Conv2D(120, 480, kernel_size=[1, 1], data_format=NCHW)
  240. )
  241. (linear_conv): ConvBNLayer(
  242. (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)
  243. (bn): BatchNorm()
  244. )
  245. )
  246. (2): ResidualUnit(
  247. (expand_conv): ConvBNLayer(
  248. (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
  249. (bn): BatchNorm()
  250. )
  251. (bottleneck_conv): ConvBNLayer(
  252. (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)
  253. (bn): BatchNorm()
  254. )
  255. (mid_se): SEModule(
  256. (avg_pool): AdaptiveAvgPool2D(output_size=1)
  257. (conv1): Conv2D(480, 120, kernel_size=[1, 1], data_format=NCHW)
  258. (conv2): Conv2D(120, 480, kernel_size=[1, 1], data_format=NCHW)
  259. )
  260. (linear_conv): ConvBNLayer(
  261. (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)
  262. (bn): BatchNorm()
  263. )
  264. )
  265. (3): ConvBNLayer(
  266. (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
  267. (bn): BatchNorm()
  268. )
  269. )
  270. )
  271. The index is 0 and the shape of output is [1, 16, 160, 160]
  272. The index is 1 and the shape of output is [1, 24, 80, 80]
  273. The index is 2 and the shape of output is [1, 56, 40, 40]
  274. The index is 3 and the shape of output is [1, 480, 20, 20]
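
The four output shapes printed above correspond to downsampling the 640x640 input by factors of 4, 8, 16, and 32. The short sketch below reproduces that arithmetic; the helper name is hypothetical, and the channel counts are copied from the printed output rather than computed.

```python
def feature_map_sizes(height, width, strides=(4, 8, 16, 32)):
    """Spatial size of each backbone output for inputs whose sides are multiples of 32."""
    return [(height // s, width // s) for s in strides]

sizes = feature_map_sizes(640, 640)
channels = [16, 24, 56, 480]  # taken from the printed backbone output
for c, (h, w) in zip(channels, sizes):
    print([1, c, h, w])
```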

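FPN network

The feature pyramid network (FPN) is a widely used convolutional structure for efficiently extracting features at multiple scales from an image.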

    # https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/necks/db_fpn.py
    import paddle
    from paddle import nn
    import paddle.nn.functional as F
    from paddle import ParamAttr

    class DBFPN(nn.Layer):
        def __init__(self, in_channels, out_channels, **kwargs):
            super(DBFPN, self).__init__()
            self.out_channels = out_channels
            # The layer definitions (in*_conv, p*_conv) are omitted here. For the full DBFPN implementation see:
            # https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/necks/db_fpn.py

        def forward(self, x):
            c2, c3, c4, c5 = x

            # Project each backbone feature with its input convolution
            in5 = self.in5_conv(c5)
            in4 = self.in4_conv(c4)
            in3 = self.in3_conv(c3)
            in2 = self.in2_conv(c2)

            # Top-down pathway: upsample the coarser level and add it to the finer one
            out4 = in4 + F.upsample(
                in5, scale_factor=2, mode="nearest", align_mode=1)  # 1/16
            out3 = in3 + F.upsample(
                out4, scale_factor=2, mode="nearest", align_mode=1)  # 1/8
            out2 = in2 + F.upsample(
                out3, scale_factor=2, mode="nearest", align_mode=1)  # 1/4

            p5 = self.p5_conv(in5)
            p4 = self.p4_conv(out4)
            p3 = self.p3_conv(out3)
            p2 = self.p2_conv(out2)

            # Upsample every level to 1/4 of the input size
            p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
            p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
            p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)

            # Concatenate the four levels along the channel axis
            fuse = paddle.concat([p5, p4, p3, p2], axis=1)
            return fuse

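The FPN takes the Backbone outputs as input and produces a feature map whose height and width are one quarter of the original image. For an input image of shape [1, 3, 640, 640], the FPN output has height and width [160, 160].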

    import paddle
    # 1. Import DBFPN from PaddleOCR
    from ppocr.modeling.necks.db_fpn import DBFPN

    # 2. Prepare the Backbone whose outputs feed the FPN
    fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
    model_backbone = MobileNetV3()
    in_channels = model_backbone.out_channels
    # 3. Instantiate the FPN network
    model_fpn = DBFPN(in_channels=in_channels, out_channels=256)
    # 4. Print the FPN network
    print(model_fpn)
    # 5. Compute the FPN output
    outs = model_backbone(fake_inputs)
    fpn_outs = model_fpn(outs)
    # 6. Print the shape of the FPN output features
    print(f"The shape of fpn outs {fpn_outs.shape}")

  1. DBFPN(
  2. (in2_conv): Conv2D(16, 256, kernel_size=[1, 1], data_format=NCHW)
  3. (in3_conv): Conv2D(24, 256, kernel_size=[1, 1], data_format=NCHW)
  4. (in4_conv): Conv2D(56, 256, kernel_size=[1, 1], data_format=NCHW)
  5. (in5_conv): Conv2D(480, 256, kernel_size=[1, 1], data_format=NCHW)
  6. (p5_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  7. (p4_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  8. (p3_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  9. (p2_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  10. )
  11. The shape of fpn outs [1, 256, 160, 160]
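
The channel and size bookkeeping of the FPN can be traced without Paddle. The NumPy sketch below is an illustration under stated assumptions, not the PaddleOCR implementation: `project` is a hypothetical stand-in that replaces every learned convolution (including the 3x3 `p*_conv` layers) with a random 1x1 channel projection, and the spatial sizes assume a small 64x64 input to keep it fast.

```python
import numpy as np

def upsample_nearest(x, scale):
    """Nearest-neighbor upsampling of an NCHW array by an integer factor."""
    return x.repeat(scale, axis=2).repeat(scale, axis=3)

def project(x, out_channels, rng):
    """Stand-in for a learned convolution: a random 1x1 channel projection."""
    weight = rng.standard_normal((out_channels, x.shape[1])).astype(x.dtype)
    return np.einsum("oc,nchw->nohw", weight, x)

rng = np.random.default_rng(0)
# Backbone outputs at strides 4, 8, 16, 32 for a hypothetical 64x64 input
shapes = [(1, 16, 16, 16), (1, 24, 8, 8), (1, 56, 4, 4), (1, 480, 2, 2)]
c2, c3, c4, c5 = [rng.standard_normal(s).astype(np.float32) for s in shapes]

# Map every level to 256 channels
in5, in4, in3, in2 = (project(c, 256, rng) for c in (c5, c4, c3, c2))

# Top-down fusion: upsample the coarser level and add it to the finer one
out4 = in4 + upsample_nearest(in5, 2)
out3 = in3 + upsample_nearest(out4, 2)
out2 = in2 + upsample_nearest(out3, 2)

# Reduce each level to 64 channels and bring it to the stride-4 resolution
p5 = upsample_nearest(project(in5, 64, rng), 8)
p4 = upsample_nearest(project(out4, 64, rng), 4)
p3 = upsample_nearest(project(out3, 64, rng), 2)
p2 = project(out2, 64, rng)

fuse = np.concatenate([p5, p4, p3, p2], axis=1)
print(fuse.shape)
```

Four levels of 64 channels each give the 256 output channels, and every level ends at one quarter of the input size, which matches the [1, 256, 160, 160] shape printed for the 640x640 input.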

Head network

The Head computes the text-region probability map, the text-region threshold map, and the text-region binary map.

    import math
    import paddle
    from paddle import nn
    import paddle.nn.functional as F
    from paddle import ParamAttr

    class DBHead(nn.Layer):
        """
        Differentiable Binarization (DB) for text detection:
            see https://arxiv.org/abs/1911.08947
        args:
            params(dict): super parameters for build DB network
        """

        def __init__(self, in_channels, k=50, **kwargs):
            super(DBHead, self).__init__()
            self.k = k
            # The binarize and thresh sub-networks are omitted here. For the full DBHead implementation see
            # https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/heads/det_db_head.py

        def step_function(self, x, y):
            # Differentiable binarization: compute the binary text segmentation map
            # from the probability map x and the threshold map y
            return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))

        def forward(self, x, targets=None):
            shrink_maps = self.binarize(x)
            # Outside training, only the probability map is returned
            if not self.training:
                return {'maps': shrink_maps}

            threshold_maps = self.thresh(x)
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
            return {'maps': y}

The DB Head upsamples the FPN features, mapping them from one quarter of the original image size back to the original image size.

    # 1. Import DBHead from PaddleOCR
    from ppocr.modeling.heads.det_db_head import DBHead
    import paddle

    # 2. Compute the DBFPN output
    fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
    model_backbone = MobileNetV3()
    in_channels = model_backbone.out_channels
    model_fpn = DBFPN(in_channels=in_channels, out_channels=256)
    outs = model_backbone(fake_inputs)
    fpn_outs = model_fpn(outs)
    # 3. Instantiate the Head network
    model_db_head = DBHead(in_channels=256)
    # 4. Print the DBHead network
    print(model_db_head)
    # 5. Compute the Head output
    db_head_outs = model_db_head(fpn_outs)
    print(f"The shape of fpn outs {fpn_outs.shape}")
    print(f"The shape of DB head outs {db_head_outs['maps'].shape}")

  1. DBHead(
  2. (binarize): Head(
  3. (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  4. (conv_bn1): BatchNorm()
  5. (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  6. (conv_bn2): BatchNorm()
  7. (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  8. )
  9. (thresh): Head(
  10. (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  11. (conv_bn1): BatchNorm()
  12. (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  13. (conv_bn2): BatchNorm()
  14. (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  15. )
  16. )
  17. The shape of fpn outs [1, 256, 160, 160]
  18. The shape of DB head outs [1, 3, 640, 640]
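
The Head printout above lists two Conv2DTranspose layers per branch, each with kernel_size=[2, 2] and stride=[2, 2]. The sketch below applies the standard output-size formula for a transposed convolution (dilation 1) to show why 160 becomes 640; the helper name is hypothetical.

```python
def conv_transpose_out_size(size, kernel, stride, padding=0, output_padding=0):
    """Output length of a transposed convolution along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

size = 160  # height/width of the FPN output for a 640x640 input
for _ in range(2):  # two stride-2 transposed convolutions
    size = conv_transpose_out_size(size, kernel=2, stride=2)
print(size)
```

Each layer doubles the spatial size, so the two layers together undo the 4x downsampling of the FPN feature map.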

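3 Training the DB text detection model

PaddleOCR provides the DB text detection algorithm with two backbones, MobileNetV3 and ResNet50_vd. Choose the corresponding config file as needed and start training.

This section uses the icdar15 dataset and a DB detection model with a MobileNetV3 backbone (the configuration used by the ultra-lightweight model) as an example of how to train, evaluate, and test a text detection model in PaddleOCR.

3.1 Data preparation

This experiment uses ICDAR2015, the best-known and most commonly used dataset for the Scene Text Detection and Recognition task. Sample images from the icdar2015 dataset are shown below:

[Figure: sample images from the icdar2015 dataset]

The icdar2015 dataset has already been downloaded in this project and is stored in /home/aistudio/data/data96799. Run the following command to extract it, or download it yourself from the link.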

    !cd ~/data/data96799/ && tar xf icdar2015.tar

After running the command above, ~/data/data96799/icdar2015/text_localization contains two folders and two files:

    ~/data/data96799/icdar2015/text_localization
      └─ icdar_c4_train_imgs/          training images of the icdar dataset
      └─ ch4_test_images/              test images of the icdar dataset
      └─ train_icdar2015_label.txt     training annotations of the icdar dataset
      └─ test_icdar2015_label.txt      test annotations of the icdar dataset

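The provided annotation files use the following format. The loading code later in this section splits the two fields on a tab character:

    " image file name    image annotation encoded with json.dumps"
    ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]

Before json.dumps encoding, the image annotation is a list of dictionaries. In each dictionary, points holds the (x, y) coordinates of the four corners of the text box, listed clockwise starting from the top-left corner. transcription holds the text inside the box, which is not needed for text detection. To train PaddleOCR on another dataset, build the annotation file in this same format.

If the "transcription" field is '*' or '###', the corresponding annotation is ignored. So when a box has no text label, set its transcription field to an empty string rather than to one of these markers.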
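
As an illustration of this format, the sketch below parses one annotation line with the standard `json` module. It is a simplified stand-in, not the PaddleOCR loader: `parse_label_line` is a hypothetical helper, and the second box in the sample line is made up to show an ignored annotation.

```python
import json

def parse_label_line(line):
    """Split one annotation line into the image path, boxes, texts, and ignore flags."""
    img_name, annotation = line.rstrip("\n").split("\t", 1)
    items = json.loads(annotation)
    boxes = [item["points"] for item in items]
    texts = [item["transcription"] for item in items]
    # '*' and '###' mark annotations that should be ignored
    ignore = [text in ("*", "###") for text in texts]
    return img_name, boxes, texts, ignore

# Sample line: the first box comes from the example above, the second is hypothetical
sample = "ch4_test_images/img_61.jpg\t" + json.dumps([
    {"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]},
    {"transcription": "###", "points": [[1, 2], [3, 2], [3, 4], [1, 4]]},
])

img_name, boxes, texts, ignore = parse_label_line(sample)
print(img_name, texts, ignore)
```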

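3.2 Data preprocessing

Training places requirements on the format and size of the input images, and the ground-truth labels for the threshold map and the probability map have to be derived from the annotations. The data is therefore preprocessed before it enters the model, so that the images and labels meet the needs of training and prediction. In addition, several basic data augmentation methods are used to enlarge the training set, suppress overfitting, and improve the model's generalization.

The data preprocessing in this experiment consists of the following steps:

  • Image decoding: convert the image to a NumPy array;
  • Label decoding: parse the label information in the txt file and store it in a unified format;
  • Basic data augmentation: random horizontal flip, random rotation, random scaling, random cropping, and so on;
  • Threshold map labels: obtain the threshold map labels needed for training by expanding the annotated text regions;
  • Probability map labels: obtain the probability map labels needed for training by shrinking the annotated text regions;
  • Normalization: standardize the input pixel values with a fixed mean and standard deviation so that the input distribution is close to zero mean and unit variance, which makes the optimization landscape smoother and training easier to converge;
  • Channel transform: image data is laid out as [H, W, C] (height, width, channels), while the network takes training data as [C, H, W], so the axes are rearranged, for example [224, 224, 3] becomes [3, 224, 224].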
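
The last two steps can be sketched in a few lines of NumPy. This is an illustration rather than the PaddleOCR operators: the function names are hypothetical, and the mean and standard deviation are the commonly used ImageNet statistics, assumed here because the text does not state the values.

```python
import numpy as np

def normalize_image(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                    scale=1.0 / 255.0):
    """Scale an HWC uint8 image to [0, 1], then standardize each channel."""
    img = img.astype(np.float32) * scale
    mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3)
    std = np.array(std, dtype=np.float32).reshape(1, 1, 3)
    return (img - mean) / std

def to_chw(img):
    """Rearrange [H, W, C] into [C, H, W]."""
    return img.transpose((2, 0, 1))

# Hypothetical random RGB image standing in for a decoded photo
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

chw = to_chw(normalize_image(image))
print(chw.shape, chw.dtype)
```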

Image decoding

    import sys
    import six
    import cv2
    import numpy as np

    # https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/operators.py
    class DecodeImage(object):
        """ decode image """

        def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
            self.img_mode = img_mode
            self.channel_first = channel_first

        def __call__(self, data):
            img = data['image']
            if six.PY2:
                assert type(img) is str and len(
                    img) > 0, "invalid input 'img' in DecodeImage"
            else:
                assert type(img) is bytes and len(
                    img) > 0, "invalid input 'img' in DecodeImage"
            # 1. Decode the raw bytes into a BGR image array
            img = np.frombuffer(img, dtype='uint8')
            img = cv2.imdecode(img, 1)
            if img is None:
                return None
            if self.img_mode == 'GRAY':
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            elif self.img_mode == 'RGB':
                # Wrap the shape in a tuple so the message formats correctly
                assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape,)
                # Reverse the channel axis: BGR -> RGB
                img = img[:, :, ::-1]

            if self.channel_first:
                img = img.transpose((2, 0, 1))
            # 2. Store the decoded image back in the dictionary
            data['image'] = img
            return data

Next, read an image listed in the training annotations to demonstrate how the DecodeImage class is used.

    import json
    import os

    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    # Required to display matplotlib.pyplot figures inside a notebook
    %matplotlib inline
    from PIL import Image

    label_path = "/home/aistudio/data/data96799/icdar2015/text_localization/train_icdar2015_label.txt"
    img_dir = "/home/aistudio/data/data96799/icdar2015/text_localization/"

    # 1. Read the training labels
    with open(label_path, "r") as f:
        lines = f.readlines()

    # 2. Take the first record
    line = lines[0]
    print("The first data in train_icdar2015_label.txt is as follows.\n", line)
    img_name, gt_label = line.strip().split("\t")

    # 3. Read the raw image bytes
    with open(os.path.join(img_dir, img_name), 'rb') as img_file:
        image = img_file.read()
    data = {'image': image, 'label': gt_label}

  7. The first data in train_icdar2015_label.txt is as follows.
  8. icdar_c4_train_imgs/img_61.jpg [{"transcription": "###", "points": [[427, 293], [469, 293], [468, 315], [425, 314]]}, {"transcription": "###", "points": [[480, 291], [651, 289], [650, 311], [479, 313]]}, {"transcription": "Ave", "points": [[655, 287], [698, 287], [696, 309], [652, 309]]}, {"transcription": "West", "points": [[701, 285], [759, 285], [759, 308], [701, 308]]}, {"transcription": "YOU", "points": [[1044, 531], [1074, 536], [1076, 585], [1046, 579]]}, {"transcription": "CAN", "points": [[1077, 535], [1114, 539], [1117, 595], [1079, 585]]}, {"transcription": "PAY", "points": [[1119, 539], [1160, 543], [1158, 601], [1120, 593]]}, {"transcription": "LESS?", "points": [[1164, 542], [1252, 545], [1253, 624], [1166, 602]]}, {"transcription": "Singapore's", "points": [[1032, 177], [1185, 73], [1191, 143], [1038, 223]]}, {"transcription": "no.1", "points": [[1190, 73], [1270, 19], [1278, 91], [1194, 133]]}]

Instantiate the DecodeImage class, decode the image, and get back the updated data dictionary.

    # 4. Instantiate the DecodeImage class and decode the image
    decode_image = DecodeImage(img_mode='RGB', channel_first=False)
    data = decode_image(data)

    # 5. Print the shape of the decoded image and display it
    print("The shape of decoded image is ", data['image'].shape)

    plt.figure(figsize=(10, 10))
    plt.imshow(data['image'])
    src_img = data['image']

  1. The shape of decoded image is (720, 1280, 3)

[Figure: the decoded training image]

Label decoding

Parse the label information in the txt file and store it in a unified format.

    import numpy as np
    import string
    import json

    # For the full implementation see: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/label_ops.py#L38
    class DetLabelEncode(object):
        def __init__(self, **kwargs):
            pass

        def __call__(self, data):
            label = data['label']
            # 1. Parse the label with json
            label = json.loads(label)
            nBox = len(label)
            boxes, txts, txt_tags = [], [], []
            for bno in range(0, nBox):
                box = label[bno]['points']
                txt = label[bno]['transcription']
                boxes.append(box)
                txts.append(txt)
                # 1.1 A transcription of * or ### marks the annotation as invalid
                if txt in ['*', '###']:
                    txt_tags.append(True)
                else:
                    txt_tags.append(False)
            if len(boxes) == 0:
                return None
            # expand_points_num is defined in the full implementation linked above
            boxes = self.expand_points_num(boxes)
            boxes = np.array(boxes, dtype=np.float32)
            # np.bool is a deprecated alias; use the builtin bool instead
            txt_tags = np.array(txt_tags, dtype=bool)

            # 2. Store the polygons, texts, and ignore flags
            data['polys'] = boxes
            data['texts'] = txts
            data['ignore_tags'] = txt_tags
            return data

Run the following code to compare the label before and after decoding with the DetLabelEncode class.

    # Import DetLabelEncode from PaddleOCR
    from ppocr.data.imaug.label_ops import DetLabelEncode

    # 1. Instantiate the label decoding class
    decode_label = DetLabelEncode()

    # 2. Print the label before decoding
    print("The label before decode are: ", data['label'])

    # 3. Decode the label
    data = decode_label(data)
    print("\n")

    # 4. Print the label after decoding
    print("The polygon after decode are: ", data['polys'])
    print("The text after decode are: ", data['texts'])


The label before decode are: [{"transcription": "###", "points": [[427, 293], [469, 293], [468, 315], [425, 314]]}, {"transcription": "###", "points": [[480, 291], [651, 289], [650, 311], [479, 313]]}, {"transcription": "Ave", "points": [[655, 287], [698, 287], [696, 309], [652, 309]]}, {"transcription": "West", "points": [[701, 285], [759, 285], [759, 308], [701, 308]]}, {"transcription": "YOU", "points": [[1044, 531], [1074, 536], [1076, 585], [1046, 579]]}, {"transcription": "CAN", "points": [[1077, 535], [1114, 539], [1117, 595], [1079, 585]]}, {"transcription": "PAY", "points": [[1119, 539], [1160, 543], [1158, 601], [1120, 593]]}, {"transcription": "LESS?", "points": [[1164, 542], [1252, 545], [1253, 624], [1166, 602]]}, {"transcription": "Singapore's", "points": [[1032, 177], [1185, 73], [1191, 143], [1038, 223]]}, {"transcription": "no.1", "points": [[1190, 73], [1270, 19], [1278, 91], [1194, 133]]}]

The polygon after decode are: [[[ 427. 293.]
  [ 469. 293.]
  [ 468. 315.]
  [ 425. 314.]]
 [[ 480. 291.]
  [ 651. 289.]
  [ 650. 311.]
  [ 479. 313.]]
 [[ 655. 287.]
  [ 698. 287.]
  [ 696. 309.]
  [ 652. 309.]]
 [[ 701. 285.]
  [ 759. 285.]
  [ 759. 308.]
  [ 701. 308.]]
 [[1044. 531.]
  [1074. 536.]
  [1076. 585.]
  [1046. 579.]]
 [[1077. 535.]
  [1114. 539.]
  [1117. 595.]
  [1079. 585.]]
 [[1119. 539.]
  [1160. 543.]
  [1158. 601.]
  [1120. 593.]]
 [[1164. 542.]
  [1252. 545.]
  [1253. 624.]
  [1166. 602.]]
 [[1032. 177.]
  [1185. 73.]
  [1191. 143.]
  [1038. 223.]]
 [[1190. 73.]
  [1270. 19.]
  [1278. 91.]
  [1194. 133.]]]
The text after decode are: ['###', '###', 'Ave', 'West', 'YOU', 'CAN', 'PAY', 'LESS?', "Singapore's", 'no.1']
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)

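Basic data augmentation

Data augmentation is a common way to improve training accuracy and model generalization. Augmentations commonly used for text detection include random horizontal flipping, random rotation, random scaling, and random cropping.

For the implementation of random horizontal flipping, random rotation, and random scaling, see the reference code; for the random-cropping augmentation, see the reference code. In the PaddleOCR training configuration these correspond to the IaaAugment (Fliplr, Affine, Resize) and EastRandomCropData transforms.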
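
As an illustration of why detection augmentation must transform the labels together with the image, the following is a minimal NumPy sketch of a horizontal flip. It is not the PaddleOCR implementation: the helper name `hflip_with_polys` and the pixel-index convention `x' = (w - 1) - x` are assumptions made for this example.

```python
import numpy as np


def hflip_with_polys(image, polys):
    """Flip an HWC image left-right and mirror the polygon x-coordinates.

    Hypothetical helper for illustration only; assumes integer pixel-index
    coordinates, so column x maps to column (w - 1) - x.
    """
    w = image.shape[1]
    flipped = image[:, ::-1, :].copy()
    out = polys.astype(np.float32).copy()
    out[..., 0] = (w - 1) - out[..., 0]
    return flipped, out


# A tiny 2x4 RGB image and one quadrilateral with shape (1, 4, 2)
image = np.arange(2 * 4 * 3, dtype=np.uint8).reshape(2, 4, 3)
polys = np.array([[[0, 0], [1, 0], [1, 1], [0, 1]]], dtype=np.float32)

flipped, flipped_polys = hflip_with_polys(image, polys)
restored, restored_polys = hflip_with_polys(flipped, flipped_polys)
```

Note that mirroring reverses the winding order of the polygon points; a real pipeline may need to reorder them if later steps rely on a fixed point order.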

Generating the threshold map label

The threshold map label required for training is generated by expanding the text polygons:

import numpy as np
import cv2
np.seterr(divide='ignore', invalid='ignore')
import pyclipper
from shapely.geometry import Polygon
import sys
import warnings
warnings.simplefilter("ignore")


# Class that computes the threshold map label for text regions
# Full implementation: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/make_border_map.py
class MakeBorderMap(object):
    def __init__(self,
                 shrink_ratio=0.4,
                 thresh_min=0.3,
                 thresh_max=0.7,
                 **kwargs):
        self.shrink_ratio = shrink_ratio
        self.thresh_min = thresh_min
        self.thresh_max = thresh_max

    def __call__(self, data):
        img = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']
        # 1. Create empty templates
        canvas = np.zeros(img.shape[:2], dtype=np.float32)
        mask = np.zeros(img.shape[:2], dtype=np.float32)
        for i in range(len(text_polys)):
            if ignore_tags[i]:
                continue
            # 2. draw_border_map (see the full implementation linked above)
            #    computes the threshold map label from the decoded box
            self.draw_border_map(text_polys[i], canvas, mask=mask)
        # Rescale the canvas into the range [thresh_min, thresh_max]
        canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
        data['threshold_map'] = canvas
        data['threshold_mask'] = mask
        return data

# Import MakeBorderMap from PaddleOCR
from ppocr.data.imaug.make_border_map import MakeBorderMap

# 1. Instantiate MakeBorderMap
generate_text_border = MakeBorderMap()

# 2. Compute the border map from the decoded input data
data = generate_text_border(data)

# 3. Visualize the threshold map (plt and src_img come from earlier cells)
plt.figure(figsize=(10, 10))
plt.imshow(src_img)
text_border_map = data['threshold_map']
plt.figure(figsize=(10, 10))
plt.imshow(text_border_map)

2. OCR Text Detection in Practice - Figure 19

2. OCR Text Detection in Practice - Figure 20
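
The last step of `MakeBorderMap.__call__` rescales the canvas linearly with `thresh_min` and `thresh_max`. The sketch below isolates that step on a toy array. It assumes that the canvas values lie in [0, 1] before rescaling (an assumption about `draw_border_map`, whose body is not shown here); the function name `scale_threshold_map` is made up for this example.

```python
import numpy as np


def scale_threshold_map(canvas, thresh_min=0.3, thresh_max=0.7):
    """Map canvas values from [0, 1] into [thresh_min, thresh_max]."""
    return canvas * (thresh_max - thresh_min) + thresh_min


# Toy canvas: 0 = far from a text border, 1 = on the border (assumed convention)
canvas = np.array([[0.0, 0.5, 1.0]], dtype=np.float32)
threshold_map = scale_threshold_map(canvas)
print(threshold_map)
```

With the default parameters the three values map to 0.3, 0.5, and 0.7, so the threshold label never leaves the range [0.3, 0.7].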

Generating the probability map label

The probability map label required for training is generated by shrinking the text polygons:

import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper


# Compute the probability map label
# Full implementation: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/make_shrink_map.py
class MakeShrinkMap(object):
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''

    def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
        self.min_text_size = min_text_size
        self.shrink_ratio = shrink_ratio

    def __call__(self, data):
        image = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']
        h, w = image.shape[:2]
        # 1. Validate the text detection labels
        #    (validate_polygons is defined in the full implementation linked above)
        text_polys, ignore_tags = self.validate_polygons(text_polys,
                                                         ignore_tags, h, w)
        gt = np.zeros((h, w), dtype=np.float32)
        mask = np.ones((h, w), dtype=np.float32)
        # 2. Compute the text-region probability map from the text boxes
        for i in range(len(text_polys)):
            polygon = text_polys[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            if ignore_tags[i] or min(height, width) < self.min_text_size:
                # Ignored or too-small text: mask it out of the loss
                cv2.fillPoly(mask,
                             polygon.astype(np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
            else:
                polygon_shape = Polygon(polygon)
                subject = [tuple(l) for l in polygon]
                padding = pyclipper.PyclipperOffset()
                padding.AddPath(subject, pyclipper.JT_ROUND,
                                pyclipper.ET_CLOSEDPOLYGON)
                shrinked = []
                # Increase the shrink ratio every time we get multiple polygon returned back
                possible_ratios = np.arange(self.shrink_ratio, 1,
                                            self.shrink_ratio)
                # np.append returns a new array, so the result must be assigned
                possible_ratios = np.append(possible_ratios, 1)
                for ratio in possible_ratios:
                    # Shrink offset: D = A * (1 - r^2) / L
                    distance = polygon_shape.area * (
                        1 - np.power(ratio, 2)) / polygon_shape.length
                    shrinked = padding.Execute(-distance)
                    if len(shrinked) == 1:
                        break
                if shrinked == []:
                    cv2.fillPoly(mask,
                                 polygon.astype(np.int32)[np.newaxis, :, :], 0)
                    ignore_tags[i] = True
                    continue
                for each_shrink in shrinked:
                    shrink = np.array(each_shrink).reshape(-1, 2)
                    cv2.fillPoly(gt, [shrink.astype(np.int32)], 1)
        data['shrink_map'] = gt
        data['shrink_mask'] = mask
        return data

# Import MakeShrinkMap from PaddleOCR
from ppocr.data.imaug.make_shrink_map import MakeShrinkMap

# 1. Instantiate the probability map label generator
generate_shrink_map = MakeShrinkMap()

# 2. Compute the text-region probability map from the decoded label
data = generate_shrink_map(data)

# 3. Visualize the text-region probability map (plt and src_img come from earlier cells)
plt.figure(figsize=(10, 10))
plt.imshow(src_img)
text_shrink_map = data['shrink_map']
plt.figure(figsize=(10, 10))
plt.imshow(text_shrink_map)

2. OCR Text Detection in Practice - Figure 21

2. OCR Text Detection in Practice - Figure 22
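
To make the shrink offset in `MakeShrinkMap` concrete, here is a small worked example of the formula D = A × (1 − r²) / L used in the loop above, where A is the polygon area, L its perimeter, and r the shrink ratio. The helper names are hypothetical, and the area and perimeter are computed with plain NumPy (shoelace formula) rather than shapely, so the sketch has no extra dependencies.

```python
import numpy as np


def polygon_area_and_perimeter(points):
    """Area (shoelace formula) and perimeter of a simple polygon."""
    pts = np.asarray(points, dtype=np.float64)
    nxt = np.roll(pts, -1, axis=0)  # each vertex's successor
    area = 0.5 * abs(np.sum(pts[:, 0] * nxt[:, 1] - nxt[:, 0] * pts[:, 1]))
    perimeter = np.sum(np.linalg.norm(nxt - pts, axis=1))
    return area, perimeter


def shrink_distance(points, shrink_ratio=0.4):
    """Offset by which the polygon is shrunk: D = A * (1 - r^2) / L."""
    area, perimeter = polygon_area_and_perimeter(points)
    return area * (1 - shrink_ratio ** 2) / perimeter


# A 100 x 20 text box: A = 2000, L = 240
box = [[0, 0], [100, 0], [100, 20], [0, 20]]
d = shrink_distance(box, shrink_ratio=0.4)
print(d)  # 2000 * 0.84 / 240 = 7.0
```

For this box every edge moves inward by 7 pixels, so the 20-pixel-high box keeps only a 6-pixel-high core as its probability map label.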

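Normalization

Normalization rescales the pixel values and standardizes each channel with a fixed mean and standard deviation, so that the network input is distributed roughly like a standard normal distribution (mean 0, variance 1). This smooths the optimization landscape and makes training converge more easily: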

import numpy as np


# Image normalization class
class NormalizeImage(object):
    """ normalize image such as subtract mean, divide std
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        # scale may come from a YAML config as a string such as '1./255.'
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        # 1. Get the mean and standard deviation used for normalization
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]
        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        # 2. Get the image from the data dict
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        # 3. Normalize the image
        data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std
        return data
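
As a quick check of the arithmetic, the following self-contained sketch applies the same formula, `(img * scale - mean) / std`, to a pure white HWC image. The function name `normalize_hwc` is made up for this example; the mean, std, and scale values are the defaults shown above, reshaped to `(1, 1, 3)` as for `order='hwc'`.

```python
import numpy as np

# Per-channel statistics, shaped to broadcast over an HWC image
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
scale = np.float32(1.0 / 255.0)


def normalize_hwc(img):
    """Scale uint8 pixels to [0, 1], then subtract the mean and divide by the std."""
    return (img.astype('float32') * scale - mean) / std


# A 2x2 white image: every pixel is (255, 255, 255)
white = np.full((2, 2, 3), 255, dtype=np.uint8)
normalized = normalize_hwc(white)
print(normalized[0, 0])
```

Each channel of a white pixel becomes `(1 - mean) / std`, roughly 2.25, 2.43, and 2.64. Because the statistics are broadcast along the last axis, this form only works before the HWC to CHW transpose.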

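Channel transposition

Images are stored as [H, W, C] (height, width, channels), whereas the network expects training data as [C, H, W]. The image data therefore has to be rearranged, for example from [224, 224, 3] to [3, 224, 224]: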

import numpy as np


# Change the channel order of the image, HWC to CHW
class ToCHWImage(object):
    """ convert hwc image to chw image
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        # 1. Get the image from the data dict
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        # 2. Transpose the axes to change the channel order
        data['image'] = img.transpose((2, 0, 1))
        return data


# 1. Instantiate the channel transposition class
transpose = ToCHWImage()
# 2. Print the image shape before the transposition
print("The shape of image before transpose", data['image'].shape)
# 3. Transpose the image channels
data = transpose(data)
# 4. Print the image shape after the transposition
print("The shape of image after transpose", data['image'].shape)

The shape of image before transpose (720, 1280, 3)
The shape of image after transpose (3, 720, 1280)

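3.3 Building the data loader

The code above only shows how to read and preprocess a single image. In actual model training, data is usually read and processed in batches.

This section builds the data loader with the Dataset and DataLoader APIs of PaddlePaddle.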

# Full data loader implementation: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/simple_dataset.py
import numpy as np
import os
import random
from paddle.io import Dataset


def transform(data, ops=None):
    """Apply each operator in ops to data, in order."""
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


def create_operators(op_param_list, global_config=None):
    """
    create operators based on the config
    Args:
        params(list): a dict list, used to create some operators
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = eval(op_name)(**param)
        ops.append(op)
    return ops


class SimpleDataSet(Dataset):
    def __init__(self, mode, label_file, data_dir, seed=None, ops=None):
        super(SimpleDataSet, self).__init__()
        # The label file uses '\t' to separate the image name from its label
        self.delimiter = '\t'
        # Dataset root directory
        self.data_dir = data_dir
        # Random seed
        self.seed = seed
        # Preprocessing operators (a list of callables, e.g. from create_operators)
        self.ops = ops if ops is not None else []
        # Read all label lines into a list
        self.data_lines = self.get_image_info_list(label_file)
        # List of sample indices
        self.data_idx_order_list = list(range(len(self.data_lines)))
        self.mode = mode
        # Shuffle the dataset when training
        if self.mode.lower() == "train":
            self.shuffle_data_random()

    def get_image_info_list(self, label_file):
        # Read every line of the label file
        with open(label_file, "rb") as f:
            lines = f.readlines()
        return lines

    def shuffle_data_random(self):
        # Shuffle the data randomly
        random.seed(self.seed)
        random.shuffle(self.data_lines)
        return

    def __getitem__(self, idx):
        # 1. Get the sample with index idx
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            # 2. Get the image name and the label
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split(self.delimiter)
            file_name = substr[0]
            label = substr[1]
            # 3. Build the image path
            img_path = os.path.join(self.data_dir, file_name)
            data = {'img_path': img_path, 'label': label}
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            # 4. Read the raw image bytes
            with open(data['img_path'], 'rb') as f:
                img = f.read()
                data['image'] = img
            # 5. Apply the preprocessing and augmentation operators;
            #    transform expects a list of callables, not the mode string
            outs = transform(data, self.ops)
        except Exception as e:
            outs = None
        # 6. If this sample could not be read, pick another one at random
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        # Return the size of the dataset
        return len(self.data_idx_order_list)
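
Each line of the label file holds an image name and a JSON-encoded annotation list separated by a tab. The sketch below parses one such line with the standard library only. The function name `parse_label_line` and the file name `img_1.jpg` are hypothetical; the annotation itself is taken from the decoded label printed earlier.

```python
import json


def parse_label_line(line, delimiter="\t"):
    """Split one raw label line into (file_name, list of annotation dicts)."""
    text = line.decode("utf-8").rstrip("\n")
    # Split only on the first delimiter so the JSON part stays intact
    file_name, label = text.split(delimiter, 1)
    return file_name, json.loads(label)


# One raw line as it would be returned by readlines() in binary mode
raw = (b'img_1.jpg\t[{"transcription": "Ave", "points": '
       b'[[655, 287], [698, 287], [696, 309], [652, 309]]}]\n')

file_name, annotations = parse_label_line(raw)
print(file_name, annotations[0]["transcription"], len(annotations[0]["points"]))
```

In the class above, `__getitem__` only performs the split and keeps the label as a string; decoding the JSON is left to the DetLabelEncode operator.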

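The PaddlePaddle DataLoader API supports multi-process data loading, and the number of worker processes is set with num_workers. Loading data in parallel speeds up data processing and, in turn, model training. It can be implemented as follows: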

from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler


def build_dataloader(mode, label_file, data_dir, batch_size, drop_last, shuffle, num_workers, seed=None):
    # Create the dataset
    dataset = SimpleDataSet(mode, label_file, data_dir, seed)
    # Define the batch sampler
    batch_sampler = BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    # Create the data loader with paddle.io.DataLoader; the batch size comes from
    # batch_sampler and the number of worker processes from num_workers
    data_loader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, num_workers=num_workers, return_list=True, use_shared_memory=False)
    return data_loader

ic15_data_path = "/home/aistudio/data/data96799/icdar2015/text_localization/"
train_data_label = "/home/aistudio/data/data96799/icdar2015/text_localization/train_icdar2015_label.txt"
eval_data_label = "/home/aistudio/data/data96799/icdar2015/text_localization/test_icdar2015_label.txt"
# Define the training data loader; num_workers=0 loads data in the main process
# (increase it, e.g. to 8, to load with multiple worker processes)
train_dataloader = build_dataloader('Train', train_data_label, ic15_data_path, batch_size=8, drop_last=False, shuffle=True, num_workers=0)
# Define the evaluation data loader
eval_dataloader = build_dataloader('Eval', eval_data_label, ic15_data_path, batch_size=1, drop_last=False, shuffle=False, num_workers=0)

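3.4 DB model post-processing

The output of the DB head has the same spatial size as the input image. Its three output channels are the probability map of the text regions, the threshold map, and the binary map.

During training, the three predicted maps are compared with the ground-truth labels to compute the loss and train the model.

During inference, only the probability map is needed. The DB post-processing function uses the responses of the text regions in the probability map to compute the coordinates of the boxes that enclose them.

Because the probability map predicted by the network corresponds to shrunk text regions, the post-processing step expands each predicted polygon by a matching offset to obtain the final text box. The code is shown below.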

# https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/postprocess/db_postprocess.py
import numpy as np
import cv2
import paddle
from shapely.geometry import Polygon
import pyclipper


class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        # 1. Store the post-processing hyperparameters
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    # For the full DB post-processing code (including boxes_from_bitmap), see
    # https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/postprocess/db_postprocess.py
    def __call__(self, outs_dict, shape_list):
        # 1. Get the network prediction from the output dict
        pred = outs_dict['maps']
        if isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        # 2. Binarize: keep pixels above the threshold self.thresh
        segmentation = pred > self.thresh
        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            # 3. Get the original image shape and the resize ratios
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            # 4. boxes_from_bitmap computes the text boxes from the
            #    predicted text probability map
            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                   src_w, src_h)
            boxes_batch.append({'points': boxes})
        return boxes_batch
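
The expansion step that undoes the training-time shrinking can be illustrated with a small calculation. The sketch below assumes the offset formula of the DB method, D' = A' × unclip_ratio / L', where A' and L' are the area and perimeter of the predicted (shrunk) polygon. The helper name `unclip_distance` is hypothetical, and the geometry is computed with plain NumPy instead of shapely and pyclipper.

```python
import numpy as np


def unclip_distance(points, unclip_ratio=1.5):
    """Expansion offset for a predicted polygon: D' = A' * unclip_ratio / L'."""
    pts = np.asarray(points, dtype=np.float64)
    nxt = np.roll(pts, -1, axis=0)  # each vertex's successor
    # Shoelace formula for the area, summed edge lengths for the perimeter
    area = 0.5 * abs(np.sum(pts[:, 0] * nxt[:, 1] - nxt[:, 0] * pts[:, 1]))
    perimeter = np.sum(np.linalg.norm(nxt - pts, axis=1))
    return area * unclip_ratio / perimeter


# A predicted 100 x 20 region: A' = 2000, L' = 240
shrunk_box = [[0, 0], [100, 0], [100, 20], [0, 20]]
d_small = unclip_distance(shrunk_box, unclip_ratio=1.5)  # 12.5
d_large = unclip_distance(shrunk_box, unclip_ratio=4.0)
print(d_small, d_large)
```

A larger `unclip_ratio` therefore pushes every edge further outward, which is why raising it from 1.5 to 4.0 produces visibly larger text boxes.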

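In the detection result, each word is enclosed in a blue box. These boxes are obtained by post-processing the segmentation output of DB. Adding the following code at line 177 of PaddleOCR/ppocr/postprocess/db_postprocess.py visualizes the segmentation map output by DB and saves it as the image vis_segmentation.png.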

import cv2
# Scale the probability map of the first image to 0-255 and save it
_maps = np.array(pred[0, :, :] * 255).astype(np.uint8)
cv2.imwrite("vis_segmentation.png", _maps)

# 1. Download the trained model
!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
!cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
# 2. Run text detection inference to get the result
!python tools/infer_det.py -c configs/det/det_mv3_db.yml \
                           -o Global.checkpoints=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy \
                              Global.infer_img=./doc/imgs_en/img_12.jpg
#PostProcess.unclip_ratio=4.0
# Note: for a description of the PostProcess and Global parameters and how to use them, see https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/doc/doc_ch/config.md

File ‘./pretrain_models/det_mv3_db_v2.0_train.tar’ already there; not retrieving.
[2021/12/22 14:56:13] root INFO: Architecture :
[2021/12/22 14:56:13] root INFO:     Backbone :
[2021/12/22 14:56:13] root INFO:         model_name : large
[2021/12/22 14:56:13] root INFO:         name : MobileNetV3
[2021/12/22 14:56:13] root INFO:         scale : 0.5
[2021/12/22 14:56:13] root INFO:     Head :
[2021/12/22 14:56:13] root INFO:         k : 50
[2021/12/22 14:56:13] root INFO:         name : DBHead
[2021/12/22 14:56:13] root INFO:     Neck :
[2021/12/22 14:56:13] root INFO:         name : DBFPN
[2021/12/22 14:56:13] root INFO:         out_channels : 256
[2021/12/22 14:56:13] root INFO:     Transform : None
[2021/12/22 14:56:13] root INFO:     algorithm : DB
[2021/12/22 14:56:13] root INFO:     model_type : det
[2021/12/22 14:56:13] root INFO: Eval :
[2021/12/22 14:56:13] root INFO:     dataset :
[2021/12/22 14:56:13] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/22 14:56:13] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']
[2021/12/22 14:56:13] root INFO:         name : SimpleDataSet
[2021/12/22 14:56:13] root INFO:         transforms :
[2021/12/22 14:56:13] root INFO:             DecodeImage :
[2021/12/22 14:56:13] root INFO:                 channel_first : False
[2021/12/22 14:56:13] root INFO:                 img_mode : BGR
[2021/12/22 14:56:13] root INFO:             DetLabelEncode : None
[2021/12/22 14:56:13] root INFO:             DetResizeForTest :
[2021/12/22 14:56:13] root INFO:                 image_shape : [736, 1280]
[2021/12/22 14:56:13] root INFO:             NormalizeImage :
[2021/12/22 14:56:13] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/22 14:56:13] root INFO:                 order : hwc
[2021/12/22 14:56:13] root INFO:                 scale : 1./255.
[2021/12/22 14:56:13] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/22 14:56:13] root INFO:             ToCHWImage : None
[2021/12/22 14:56:13] root INFO:             KeepKeys :
[2021/12/22 14:56:13] root INFO:                 keep_keys : ['image', 'shape', 'polys', 'ignore_tags']
[2021/12/22 14:56:13] root INFO:     loader :
[2021/12/22 14:56:13] root INFO:         batch_size_per_card : 1
[2021/12/22 14:56:13] root INFO:         drop_last : False
[2021/12/22 14:56:13] root INFO:         num_workers : 8
[2021/12/22 14:56:13] root INFO:         shuffle : False
[2021/12/22 14:56:13] root INFO:         use_shared_memory : False
[2021/12/22 14:56:13] root INFO: Global :
[2021/12/22 14:56:13] root INFO:     cal_metric_during_train : False
[2021/12/22 14:56:13] root INFO:     checkpoints : ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
[2021/12/22 14:56:13] root INFO:     debug : False
[2021/12/22 14:56:13] root INFO:     distributed : False
[2021/12/22 14:56:13] root INFO:     epoch_num : 1200
[2021/12/22 14:56:13] root INFO:     eval_batch_step : [0, 2000]
[2021/12/22 14:56:13] root INFO:     infer_img : ./doc/imgs_en/img_12.jpg
[2021/12/22 14:56:13] root INFO:     log_smooth_window : 20
[2021/12/22 14:56:13] root INFO:     pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained
[2021/12/22 14:56:13] root INFO:     print_batch_step : 10
[2021/12/22 14:56:13] root INFO:     save_epoch_step : 1200
[2021/12/22 14:56:13] root INFO:     save_inference_dir : None
[2021/12/22 14:56:13] root INFO:     save_model_dir : ./output/db_mv3/
[2021/12/22 14:56:13] root INFO:     save_res_path : ./output/det_db/predicts_db.txt
[2021/12/22 14:56:13] root INFO:     use_gpu : True
[2021/12/22 14:56:13] root INFO:     use_visualdl : False
[2021/12/22 14:56:13] root INFO: Loss :
[2021/12/22 14:56:13] root INFO:     alpha : 5
[2021/12/22 14:56:13] root INFO:     balance_loss : True
[2021/12/22 14:56:13] root INFO:     beta : 10
[2021/12/22 14:56:13] root INFO:     main_loss_type : DiceLoss
[2021/12/22 14:56:13] root INFO:     name : DBLoss
[2021/12/22 14:56:13] root INFO:     ohem_ratio : 3
[2021/12/22 14:56:13] root INFO: Metric :
[2021/12/22 14:56:13] root INFO:     main_indicator : hmean
[2021/12/22 14:56:13] root INFO:     name : DetMetric
[2021/12/22 14:56:13] root INFO: Optimizer :
[2021/12/22 14:56:13] root INFO:     beta1 : 0.9
[2021/12/22 14:56:13] root INFO:     beta2 : 0.999
[2021/12/22 14:56:13] root INFO:     lr :
[2021/12/22 14:56:13] root INFO:         learning_rate : 0.001
[2021/12/22 14:56:13] root INFO:     name : Adam
[2021/12/22 14:56:13] root INFO:     regularizer :
[2021/12/22 14:56:13] root INFO:         factor : 0
[2021/12/22 14:56:13] root INFO:         name : L2
[2021/12/22 14:56:13] root INFO: PostProcess :
[2021/12/22 14:56:13] root INFO:     box_thresh : 0.6
[2021/12/22 14:56:13] root INFO:     max_candidates : 1000
[2021/12/22 14:56:13] root INFO:     name : DBPostProcess
[2021/12/22 14:56:13] root INFO:     thresh : 0.3
[2021/12/22 14:56:13] root INFO:     unclip_ratio : 1.5
[2021/12/22 14:56:13] root INFO: Train :
[2021/12/22 14:56:13] root INFO:     dataset :
[2021/12/22 14:56:13] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/22 14:56:13] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']
[2021/12/22 14:56:13] root INFO:         name : SimpleDataSet
[2021/12/22 14:56:13] root INFO:         ratio_list : [1.0]
[2021/12/22 14:56:13] root INFO:         transforms :
[2021/12/22 14:56:13] root INFO:             DecodeImage :
[2021/12/22 14:56:13] root INFO:                 channel_first : False
[2021/12/22 14:56:13] root INFO:                 img_mode : BGR
[2021/12/22 14:56:13] root INFO:             DetLabelEncode : None
[2021/12/22 14:56:13] root INFO:             IaaAugment :
[2021/12/22 14:56:13] root INFO:                 augmenter_args :
[2021/12/22 14:56:13] root INFO:                     args :
[2021/12/22 14:56:13] root INFO:                         p : 0.5
[2021/12/22 14:56:13] root INFO:                     type : Fliplr
[2021/12/22 14:56:13] root INFO:                     args :
[2021/12/22 14:56:13] root INFO:                         rotate : [-10, 10]
[2021/12/22 14:56:13] root INFO:                     type : Affine
[2021/12/22 14:56:13] root INFO:                     args :
[2021/12/22 14:56:13] root INFO:                         size : [0.5, 3]
[2021/12/22 14:56:13] root INFO:                     type : Resize
[2021/12/22 14:56:13] root INFO:             EastRandomCropData :
[2021/12/22 14:56:13] root INFO:                 keep_ratio : True
[2021/12/22 14:56:13] root INFO:                 max_tries : 50
[2021/12/22 14:56:13] root INFO:                 size : [640, 640]
[2021/12/22 14:56:13] root INFO:             MakeBorderMap :
[2021/12/22 14:56:13] root INFO:                 shrink_ratio : 0.4
[2021/12/22 14:56:13] root INFO:                 thresh_max : 0.7
[2021/12/22 14:56:13] root INFO:                 thresh_min : 0.3
[2021/12/22 14:56:13] root INFO:             MakeShrinkMap :
[2021/12/22 14:56:13] root INFO:                 min_text_size : 8
[2021/12/22 14:56:13] root INFO:                 shrink_ratio : 0.4
[2021/12/22 14:56:13] root INFO:             NormalizeImage :
[2021/12/22 14:56:13] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/22 14:56:13] root INFO:                 order : hwc
[2021/12/22 14:56:13] root INFO:                 scale : 1./255.
[2021/12/22 14:56:13] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/22 14:56:13] root INFO:             ToCHWImage : None
[2021/12/22 14:56:13] root INFO:             KeepKeys :
[2021/12/22 14:56:13] root INFO:                 keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
[2021/12/22 14:56:13] root INFO:     loader :
[2021/12/22 14:56:13] root INFO:         batch_size_per_card : 16
[2021/12/22 14:56:13] root INFO:         drop_last : False
[2021/12/22 14:56:13] root INFO:         num_workers : 8
[2021/12/22 14:56:13] root INFO:         shuffle : True
[2021/12/22 14:56:13] root INFO:         use_shared_memory : False
[2021/12/22 14:56:13] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
W1222 14:56:13.651367 1415 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W1222 14:56:13.655743 1415 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021/12/22 14:56:16] root INFO: resume from ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
[2021/12/22 14:56:16] root INFO: infer_img: ./doc/imgs_en/img_12.jpg
[2021/12/22 14:56:17] root INFO: The detected Image saved in ./output/det_db/det_results/img_12.jpg
[2021/12/22 14:56:17] root INFO: success!

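Visualize the text probability map predicted by the model and the final predicted text boxes.

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

img = Image.open('./output/det_db/det_results/img_12.jpg')
img = np.array(img)

# Show the detection result image
plt.figure(figsize=(10, 10))
plt.imshow(img)

img = Image.open('./vis_segmentation.png')
img = np.array(img)

# Show the segmentation map image
plt.figure(figsize=(10, 10))
plt.imshow(img)

2. OCR Text Detection in Practice - Figure 23

2. OCR Text Detection in Practice - Figure 24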

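The visualization shows that the output of DB is a binary-like map of the text regions: pixels that belong to text respond strongly, while non-text background pixels respond weakly. DB post-processing finds the minimum bounding box of each response region, which gives the coordinates of each text region.
In addition, the post-processing parameters can be changed to adjust the size of the text boxes or to filter out poorly detected boxes.

DB post-processing has four parameters:

  • thresh: the threshold DBPostProcess uses to binarize the segmentation map; the default is 0.3
  • box_thresh: the threshold DBPostProcess uses to filter output boxes; boxes that score below it are not output
  • unclip_ratio: the ratio by which DBPostProcess enlarges the text boxes
  • max_candidates: the maximum number of text boxes DBPostProcess outputs; the default is 1000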
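
The roles of `thresh` and `box_thresh` can be sketched on a toy probability map. This is a simplified illustration, not the PaddleOCR code: it assumes axis-aligned candidate boxes and scores each box by the mean probability inside it, whereas the real `boxes_from_bitmap` works on contours. The names `binarize`, `box_score`, and `candidates` are made up for this example.

```python
import numpy as np


def binarize(prob_map, thresh=0.3):
    """Pixels above thresh are treated as text."""
    return prob_map > thresh


def box_score(prob_map, x0, y0, x1, y1):
    """Mean probability inside an axis-aligned box (corners inclusive)."""
    return float(prob_map[y0:y1 + 1, x0:x1 + 1].mean())


# Toy probability map with one confident and one weak response region
prob_map = np.zeros((6, 8), dtype=np.float32)
prob_map[1:3, 1:4] = 0.9
prob_map[4:6, 5:8] = 0.4

mask = binarize(prob_map, thresh=0.3)

# Candidate boxes (x0, y0, x1, y1) around the two regions
candidates = {"strong": (1, 1, 3, 2), "weak": (5, 4, 7, 5)}
box_thresh = 0.6
kept = [name for name, (x0, y0, x1, y1) in candidates.items()
        if box_score(prob_map, x0, y0, x1, y1) >= box_thresh]
print(int(mask.sum()), kept)
```

Both regions survive binarization at `thresh=0.3`, but only the confident one passes `box_thresh=0.6`; lowering `box_thresh` keeps more boxes at the cost of more false positives.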

# 3. Increase the DB post-processing parameter unclip_ratio from its default of 1.5 to 4.0 to enlarge the output text boxes, then run text detection inference again
!python tools/infer_det.py -c configs/det/det_mv3_db.yml \
                           -o Global.checkpoints=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy \
                              Global.infer_img=./doc/imgs_en/img_12.jpg \
                              PostProcess.unclip_ratio=4.0
# Note: for a description of the PostProcess and Global parameters and how to use them, see https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/doc/doc_ch/config.md
[2021/12/22 14:58:09] root INFO: Architecture : 
[2021/12/22 14:58:09] root INFO:     Backbone : 
[2021/12/22 14:58:09] root INFO:         model_name : large
[2021/12/22 14:58:09] root INFO:         name : MobileNetV3
[2021/12/22 14:58:09] root INFO:         scale : 0.5
[2021/12/22 14:58:09] root INFO:     Head : 
[2021/12/22 14:58:09] root INFO:         k : 50
[2021/12/22 14:58:09] root INFO:         name : DBHead
[2021/12/22 14:58:09] root INFO:     Neck : 
[2021/12/22 14:58:09] root INFO:         name : DBFPN
[2021/12/22 14:58:09] root INFO:         out_channels : 256
[2021/12/22 14:58:09] root INFO:     Transform : None
[2021/12/22 14:58:09] root INFO:     algorithm : DB
[2021/12/22 14:58:09] root INFO:     model_type : det
[2021/12/22 14:58:09] root INFO: Eval : 
[2021/12/22 14:58:09] root INFO:     dataset : 
[2021/12/22 14:58:09] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/22 14:58:09] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']
[2021/12/22 14:58:09] root INFO:         name : SimpleDataSet
[2021/12/22 14:58:09] root INFO:         transforms : 
[2021/12/22 14:58:09] root INFO:             DecodeImage : 
[2021/12/22 14:58:09] root INFO:                 channel_first : False
[2021/12/22 14:58:09] root INFO:                 img_mode : BGR
[2021/12/22 14:58:09] root INFO:             DetLabelEncode : None
[2021/12/22 14:58:09] root INFO:             DetResizeForTest : 
[2021/12/22 14:58:09] root INFO:                 image_shape : [736, 1280]
[2021/12/22 14:58:09] root INFO:             NormalizeImage : 
[2021/12/22 14:58:09] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/22 14:58:09] root INFO:                 order : hwc
[2021/12/22 14:58:09] root INFO:                 scale : 1./255.
[2021/12/22 14:58:09] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/22 14:58:09] root INFO:             ToCHWImage : None
[2021/12/22 14:58:09] root INFO:             KeepKeys : 
[2021/12/22 14:58:09] root INFO:                 keep_keys : ['image', 'shape', 'polys', 'ignore_tags']
[2021/12/22 14:58:09] root INFO:     loader : 
[2021/12/22 14:58:09] root INFO:         batch_size_per_card : 1
[2021/12/22 14:58:09] root INFO:         drop_last : False
[2021/12/22 14:58:09] root INFO:         num_workers : 8
[2021/12/22 14:58:09] root INFO:         shuffle : False
[2021/12/22 14:58:09] root INFO:         use_shared_memory : False
[2021/12/22 14:58:09] root INFO: Global : 
[2021/12/22 14:58:09] root INFO:     cal_metric_during_train : False
[2021/12/22 14:58:09] root INFO:     checkpoints : ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
[2021/12/22 14:58:09] root INFO:     debug : False
[2021/12/22 14:58:09] root INFO:     distributed : False
[2021/12/22 14:58:09] root INFO:     epoch_num : 1200
[2021/12/22 14:58:09] root INFO:     eval_batch_step : [0, 2000]
[2021/12/22 14:58:09] root INFO:     infer_img : ./doc/imgs_en/img_12.jpg
[2021/12/22 14:58:09] root INFO:     log_smooth_window : 20
[2021/12/22 14:58:09] root INFO:     pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained
[2021/12/22 14:58:09] root INFO:     print_batch_step : 10
[2021/12/22 14:58:09] root INFO:     save_epoch_step : 1200
[2021/12/22 14:58:09] root INFO:     save_inference_dir : None
[2021/12/22 14:58:09] root INFO:     save_model_dir : ./output/db_mv3/
[2021/12/22 14:58:09] root INFO:     save_res_path : ./output/det_db/predicts_db.txt
[2021/12/22 14:58:09] root INFO:     use_gpu : True
[2021/12/22 14:58:09] root INFO:     use_visualdl : False
[2021/12/22 14:58:09] root INFO: Loss : 
[2021/12/22 14:58:09] root INFO:     alpha : 5
[2021/12/22 14:58:09] root INFO:     balance_loss : True
[2021/12/22 14:58:09] root INFO:     beta : 10
[2021/12/22 14:58:09] root INFO:     main_loss_type : DiceLoss
[2021/12/22 14:58:09] root INFO:     name : DBLoss
[2021/12/22 14:58:09] root INFO:     ohem_ratio : 3
[2021/12/22 14:58:09] root INFO: Metric : 
[2021/12/22 14:58:09] root INFO:     main_indicator : hmean
[2021/12/22 14:58:09] root INFO:     name : DetMetric
[2021/12/22 14:58:09] root INFO: Optimizer : 
[2021/12/22 14:58:09] root INFO:     beta1 : 0.9
[2021/12/22 14:58:09] root INFO:     beta2 : 0.999
[2021/12/22 14:58:09] root INFO:     lr : 
[2021/12/22 14:58:09] root INFO:         learning_rate : 0.001
[2021/12/22 14:58:09] root INFO:     name : Adam
[2021/12/22 14:58:09] root INFO:     regularizer : 
[2021/12/22 14:58:09] root INFO:         factor : 0
[2021/12/22 14:58:09] root INFO:         name : L2
[2021/12/22 14:58:09] root INFO: PostProcess : 
[2021/12/22 14:58:09] root INFO:     box_thresh : 0.6
[2021/12/22 14:58:09] root INFO:     max_candidates : 1000
[2021/12/22 14:58:09] root INFO:     name : DBPostProcess
[2021/12/22 14:58:09] root INFO:     thresh : 0.3
[2021/12/22 14:58:09] root INFO:     unclip_ratio : 4.0
[2021/12/22 14:58:09] root INFO: Train : 
[2021/12/22 14:58:09] root INFO:     dataset : 
[2021/12/22 14:58:09] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/22 14:58:09] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']
[2021/12/22 14:58:09] root INFO:         name : SimpleDataSet
[2021/12/22 14:58:09] root INFO:         ratio_list : [1.0]
[2021/12/22 14:58:09] root INFO:         transforms : 
[2021/12/22 14:58:09] root INFO:             DecodeImage : 
[2021/12/22 14:58:09] root INFO:                 channel_first : False
[2021/12/22 14:58:09] root INFO:                 img_mode : BGR
[2021/12/22 14:58:09] root INFO:             DetLabelEncode : None
[2021/12/22 14:58:09] root INFO:             IaaAugment : 
[2021/12/22 14:58:09] root INFO:                 augmenter_args : 
[2021/12/22 14:58:09] root INFO:                     args : 
[2021/12/22 14:58:09] root INFO:                         p : 0.5
[2021/12/22 14:58:09] root INFO:                     type : Fliplr
[2021/12/22 14:58:09] root INFO:                     args : 
[2021/12/22 14:58:09] root INFO:                         rotate : [-10, 10]
[2021/12/22 14:58:09] root INFO:                     type : Affine
[2021/12/22 14:58:09] root INFO:                     args : 
[2021/12/22 14:58:09] root INFO:                         size : [0.5, 3]
[2021/12/22 14:58:09] root INFO:                     type : Resize
[2021/12/22 14:58:09] root INFO:             EastRandomCropData : 
[2021/12/22 14:58:09] root INFO:                 keep_ratio : True
[2021/12/22 14:58:09] root INFO:                 max_tries : 50
[2021/12/22 14:58:09] root INFO:                 size : [640, 640]
[2021/12/22 14:58:09] root INFO:             MakeBorderMap : 
[2021/12/22 14:58:09] root INFO:                 shrink_ratio : 0.4
[2021/12/22 14:58:09] root INFO:                 thresh_max : 0.7
[2021/12/22 14:58:09] root INFO:                 thresh_min : 0.3
[2021/12/22 14:58:09] root INFO:             MakeShrinkMap : 
[2021/12/22 14:58:09] root INFO:                 min_text_size : 8
[2021/12/22 14:58:09] root INFO:                 shrink_ratio : 0.4
[2021/12/22 14:58:09] root INFO:             NormalizeImage : 
[2021/12/22 14:58:09] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/22 14:58:09] root INFO:                 order : hwc
[2021/12/22 14:58:09] root INFO:                 scale : 1./255.
[2021/12/22 14:58:09] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/22 14:58:09] root INFO:             ToCHWImage : None
[2021/12/22 14:58:09] root INFO:             KeepKeys : 
[2021/12/22 14:58:09] root INFO:                 keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
[2021/12/22 14:58:09] root INFO:     loader : 
[2021/12/22 14:58:09] root INFO:         batch_size_per_card : 16
[2021/12/22 14:58:09] root INFO:         drop_last : False
[2021/12/22 14:58:09] root INFO:         num_workers : 8
[2021/12/22 14:58:09] root INFO:         shuffle : True
[2021/12/22 14:58:09] root INFO:         use_shared_memory : False
[2021/12/22 14:58:09] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
W1222 14:58:09.835049  1556 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W1222 14:58:09.839382  1556 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021/12/22 14:58:13] root INFO: resume from ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
[2021/12/22 14:58:13] root INFO: infer_img: ./doc/imgs_en/img_12.jpg
[2021/12/22 14:58:13] root INFO: The detected Image saved in ./output/det_db/det_results/img_12.jpg
[2021/12/22 14:58:13] root INFO: success!
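from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Load the detection result saved by tools/infer_det.py
img = Image.open('./output/det_db/det_results/img_12.jpg')
img = np.array(img)

# Display the loaded image
plt.figure(figsize=(10, 10))
plt.imshow(img)

img = Image.open('./vis_segmentation.png')
img = np.array(img)

# Display the loaded image
plt.figure(figsize=(10, 10))
plt.imshow(img)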
<matplotlib.image.AxesImage at 0x7f6e7a6060d0>

2. OCR Text Detection in Practice - Figure 25

2. OCR Text Detection in Practice - Figure 26

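As the output above shows, increasing the unclip_ratio parameter of the DB post-processing makes the predicted text boxes noticeably larger. So when detection results fall short of expectations, adjusting the post-processing parameters is one way to change them. You can also vary the other three parameters, thresh, box_thresh, and max_candidates, and compare the detection results.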
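
To make the effect of unclip_ratio concrete, the sketch below computes the expansion distance applied to a shrunk text region, following the offset formula D = A × r / L from the DB paper (A is the polygon area, L its perimeter, r the unclip ratio). The helper names `unclip_distance` and `expanded_size` and the 100×20 box are illustrative assumptions; the real post-processing offsets arbitrary polygons rather than axis-aligned rectangles, so treat the sizes as approximations.

```python
def unclip_distance(width, height, unclip_ratio):
    """Offset distance D = area * unclip_ratio / perimeter for a rectangle."""
    area = width * height
    perimeter = 2 * (width + height)
    return area * unclip_ratio / perimeter


def expanded_size(width, height, unclip_ratio):
    """Approximate size after pushing every edge outward by D (corner rounding ignored)."""
    d = unclip_distance(width, height, unclip_ratio)
    return width + 2 * d, height + 2 * d


# Compare the default ratio (1.5) with the enlarged ratio (4.0) used above
for ratio in (1.5, 4.0):
    w, h = expanded_size(100, 20, ratio)
    print(f"unclip_ratio={ratio}: {w:.1f} x {h:.1f}")
```

Under these assumptions a 100×20 region grows to about 125×45 with a ratio of 1.5 and to about 167×87 with a ratio of 4.0, which is consistent with the visibly larger boxes in the figure.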

3.5 Loss Function Definition

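The training stage produces three prediction maps, so the loss function combines three terms, each computed between one prediction map and its ground-truth label. The total loss is defined as:

$$L = L_b + \alpha \times L_s + \beta \times L_t$$

Here $L$ is the total loss. $L_s$ is the probability-map loss, for which this experiment uses Dice loss with OHEM (online hard example mining). $L_t$ is the threshold-map loss, for which this experiment uses the $L_1$ distance between prediction and label. $L_b$ is the loss on the text binary map. $\alpha$ and $\beta$ are weighting coefficients, set to 5 and 10 respectively in this experiment.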
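
As a quick numeric check of the weighted sum described above, the sketch below combines three made-up loss values with the weights used in this experiment (alpha on the probability-map term, beta on the threshold-map term, matching how the DBLoss code in this section applies them). The function name `total_db_loss` and the sample values are illustrative assumptions, not outputs of a real training run.

```python
def total_db_loss(l_s, l_t, l_b, alpha=5, beta=10):
    """Total loss: binary-map term + alpha * probability-map term + beta * threshold-map term."""
    return l_b + alpha * l_s + beta * l_t


# Hypothetical per-term losses, chosen only for illustration
loss = total_db_loss(l_s=0.97, l_t=0.20, l_b=0.98)
print(round(loss, 2))
```

With these weights the probability-map term contributes most of the total, so it dominates the gradient early in training.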

The three losses $L_b$, $L_s$, and $L_t$ are Dice Loss, Dice Loss (OHEM), and MaskL1 Loss respectively. Each is described below:

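  • Dice Loss measures the similarity between the predicted text binary map and the label and is widely used for binary image segmentation. For the code implementation, see the reference link. The formula is:

$$dice\_loss = 1 - \frac{2 \times intersection\_area}{total\_area}$$

  • Dice Loss (OHEM) is Dice Loss combined with OHEM, which mitigates the imbalance between positive and negative samples. OHEM is an automatic sampling strategy that selects hard examples for the loss computation, which improves training. Here the sampling ratio of positive to negative samples is set to 1:3. For the code implementation, see the reference link.
  • MaskL1 Loss computes the $L_1$ distance between the predicted text threshold map and the label.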
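
The three building blocks above can be sketched in NumPy as follows. This is a simplified illustration under stated assumptions, not PaddleOCR's implementation: the function names are hypothetical, `mask` marks the pixels that take part in the loss, and `ohem_negative_count` only shows how a 1:3 positive-to-negative ratio caps the number of negatives that would be sampled.

```python
import numpy as np


def dice_loss(pred, gt, mask, eps=1e-6):
    """1 - 2 * intersection / (pred area + gt area), restricted to mask == 1."""
    intersection = np.sum(pred * gt * mask)
    total = np.sum(pred * mask) + np.sum(gt * mask) + eps
    return 1.0 - 2.0 * intersection / total


def mask_l1_loss(pred, gt, mask, eps=1e-6):
    """Mean absolute difference over the masked pixels."""
    return np.sum(np.abs(pred - gt) * mask) / (np.sum(mask) + eps)


def ohem_negative_count(gt, mask, negative_ratio=3):
    """Number of positives, and number of hard negatives kept at the given ratio."""
    positives = int(np.sum(gt * mask))
    negatives = int(np.sum((1 - gt) * mask))
    return positives, min(negatives, positives * negative_ratio)


# Toy 4x4 label with a 2x2 text region
gt = np.zeros((4, 4))
gt[1:3, 1:3] = 1.0
mask = np.ones((4, 4))

print(dice_loss(gt, gt, mask))          # perfect prediction, loss close to 0
print(dice_loss(1.0 - gt, gt, mask))    # no overlap, loss equals 1
print(mask_l1_loss(0.5 * gt, gt, mask))
print(ohem_negative_count(gt, mask, negative_ratio=2))
```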

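import paddle
from paddle import nn
import paddle.nn.functional as F
# Building-block losses; this import assumes the code runs from the PaddleOCR repository root
from ppocr.losses.det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss


# DB loss function
# For the full implementation, see: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/det_db_loss.py
class DBLoss(nn.Layer):
    """
    Differentiable Binarization (DB) Loss Function
    args:
        balance_loss, main_loss_type, alpha, beta, ohem_ratio, eps:
            hyperparameters of the DB loss
    """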

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 **kwargs):
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        # Instantiate the individual loss terms
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)

    def forward(self, predicts, labels):
        predict_maps = predicts['maps']
        label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
            1:]
        shrink_maps = predict_maps[:, 0, :, :]
        threshold_maps = predict_maps[:, 1, :, :]
        binary_maps = predict_maps[:, 2, :, :]
        # 1. Probability (shrink) map: balanced loss with OHEM (Dice loss when main_loss_type='DiceLoss')
        loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
                                         label_shrink_mask)
        # 2. Threshold map: masked L1 distance loss
        loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
                                           label_threshold_mask)
        # 3. Binary map: Dice loss
        loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
                                          label_shrink_mask)

        # 4. Weight the individual loss terms
        loss_shrink_maps = self.alpha * loss_shrink_maps
        loss_threshold_maps = self.beta * loss_threshold_maps

        loss_all = loss_shrink_maps + loss_threshold_maps \
                   + loss_binary_maps
        losses = {'loss': loss_all, \
                  "loss_shrink_maps": loss_shrink_maps, \
                  "loss_threshold_maps": loss_threshold_maps, \
                  "loss_binary_maps": loss_binary_maps}
        return losses

3.6 Evaluation Metrics

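Because the boxes produced by DB post-processing vary in shape and are not necessarily horizontal, this experiment evaluates detections with a simple IoU computation. The evaluation code follows the text detection evaluation method of ICDAR Challenge 4.

Text detection uses three metrics: Precision, Recall, and Hmean. They are computed as follows:

  1. Create a matrix iouMat of size [n, m], where n is the number of GT (ground truth) boxes and m is the number of detected boxes; both n and m exclude boxes whose text is annotated as ###;
  2. In iouMat, count the matched GT/detection pairs whose IoU exceeds the threshold 0.5 (each box is matched at most once) and divide by the number of GT boxes n to obtain Recall;
  3. Divide the same count of pairs whose IoU exceeds 0.5 by the number of detected boxes m to obtain Precision;
  4. Hmean is computed the same way as the F1-score:

$$Hmean = 2.0 \times \frac{Precision \times Recall}{Precision + Recall}$$

The core code for computing the text detection metrics is shown below. For the complete implementation, see the reference link.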

# Core of the text detection metric computation (excerpt from a method, not standalone)
# Full code: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/metrics/det_metric.py
if len(gtPols) > 0 and len(detPols) > 0:
    outputShape = [len(gtPols), len(detPols)]

    # 1. Create an [n, m] matrix to hold the computed IoU values
    iouMat = np.empty(outputShape)
    gtRectMat = np.zeros(len(gtPols), np.int8)
    detRectMat = np.zeros(len(detPols), np.int8)
    for gtNum in range(len(gtPols)):
        for detNum in range(len(detPols)):
            pG = gtPols[gtNum]
            pD = detPols[detNum]

            # 2. Compute the IoU between the predicted box and the GT box
            iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
    for gtNum in range(len(gtPols)):
        for detNum in range(len(detPols)):
            if gtRectMat[gtNum] == 0 and detRectMat[
                    detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:

                # 2.1 Count the pairs whose IoU exceeds the threshold 0.5
                if iouMat[gtNum, detNum] > self.iou_constraint:
                    gtRectMat[gtNum] = 1
                    detRectMat[detNum] = 1
                    detMatched += 1
                    pairs.append({'gt': gtNum, 'det': detNum})
                    detMatchedNums.append(detNum)

    # 3. Recall: number of pairs with IoU above 0.5 divided by the number of GT boxes, numGtCare
    recall = float(detMatched) / numGtCare

    # 4. Precision: number of pairs with IoU above 0.5 divided by the number of predicted boxes, numDetCare
    precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare

    # 5. Compute Hmean from precision and recall
    hmean = 0 if (precision + recall) == 0 else 2.0 * \
                                                    precision * recall / (precision + recall)
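
The excerpt above depends on state from its surrounding method, so here is a self-contained sketch of the same matching logic. It is a simplification under stated assumptions: boxes are axis-aligned rectangles `(x1, y1, x2, y2)` instead of arbitrary polygons, "don't care" boxes (text annotated as ###) are not handled, and the names `box_iou` and `detection_metrics` are hypothetical.

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def detection_metrics(gt_boxes, det_boxes, iou_threshold=0.5):
    """Return (precision, recall, hmean); each box is matched at most once."""
    gt_matched = [False] * len(gt_boxes)
    det_matched = [False] * len(det_boxes)
    num_matched = 0
    for g, gt_box in enumerate(gt_boxes):
        for d, det_box in enumerate(det_boxes):
            if gt_matched[g] or det_matched[d]:
                continue
            if box_iou(gt_box, det_box) > iou_threshold:
                gt_matched[g] = det_matched[d] = True
                num_matched += 1
    recall = num_matched / len(gt_boxes) if gt_boxes else 0.0
    precision = num_matched / len(det_boxes) if det_boxes else 0.0
    total = precision + recall
    hmean = 2.0 * precision * recall / total if total > 0 else 0.0
    return precision, recall, hmean


# Two GT boxes, three detections: two good matches and one false positive
gt = [(0, 0, 10, 10), (20, 20, 30, 30)]
det = [(0, 0, 10, 9), (50, 50, 60, 60), (21, 20, 30, 30)]
precision, recall, hmean = detection_metrics(gt, det)
print(precision, recall, hmean)
```

In this toy case both GT boxes are matched, so recall is 1.0, while the unmatched third detection lowers precision to 2/3 and Hmean to about 0.8.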

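Questions to consider:

  1. In the case shown in the figure below, the IoU between the GT box and the predicted box exceeds 0.5, yet part of the text is missed. Do the metrics above accurately reflect the model's accuracy?
  2. If this problem arises in a real scenario, how should the model be optimized?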

image.png

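3.7 Model Training

With data processing, network definition, and loss function definition complete, model training can begin.

Training is run through PaddleOCR and is driven by a configuration file (see the reference link for the file). The network architecture parameters are: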

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

The optimizer parameters are:

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

The post-processing parameters are:

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

For the complete configuration, see det_mv3_db.yml.

!mkdir train_data 
!cd train_data && ln -s /home/aistudio/data/data96799/icdar2015  icdar2015
!wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
mkdir: cannot create directory ‘train_data’: File exists
--2021-12-22 15:04:01--  https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
Resolving paddle-imagenet-models-name.bj.bcebos.com (paddle-imagenet-models-name.bj.bcebos.com)... 100.67.200.6
Connecting to paddle-imagenet-models-name.bj.bcebos.com (paddle-imagenet-models-name.bj.bcebos.com)|100.67.200.6|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16255295 (16M) [application/octet-stream]
Saving to: ‘./pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams.2’

MobileNetV3_large_x 100%[===================>]  15.50M  85.0MB/s    in 0.2s    

2021-12-22 15:04:02 (85.0 MB/s) - ‘./pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams.2’ saved [16255295/16255295]
!python tools/train.py -c configs/det/det_mv3_db.yml
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)
[2021/12/20 21:27:06] root INFO: Architecture : 
[2021/12/20 21:27:06] root INFO:     Backbone : 
[2021/12/20 21:27:06] root INFO:         model_name : large
[2021/12/20 21:27:06] root INFO:         name : MobileNetV3
[2021/12/20 21:27:06] root INFO:         scale : 0.5
[2021/12/20 21:27:06] root INFO:     Head : 
[2021/12/20 21:27:06] root INFO:         k : 50
[2021/12/20 21:27:06] root INFO:         name : DBHead
[2021/12/20 21:27:06] root INFO:     Neck : 
[2021/12/20 21:27:06] root INFO:         name : DBFPN
[2021/12/20 21:27:06] root INFO:         out_channels : 256
[2021/12/20 21:27:06] root INFO:     Transform : None
[2021/12/20 21:27:06] root INFO:     algorithm : DB
[2021/12/20 21:27:06] root INFO:     model_type : det
[2021/12/20 21:27:06] root INFO: Eval : 
[2021/12/20 21:27:06] root INFO:     dataset : 
[2021/12/20 21:27:06] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/20 21:27:06] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']
[2021/12/20 21:27:06] root INFO:         name : SimpleDataSet
[2021/12/20 21:27:06] root INFO:         transforms : 
[2021/12/20 21:27:06] root INFO:             DecodeImage : 
[2021/12/20 21:27:06] root INFO:                 channel_first : False
[2021/12/20 21:27:06] root INFO:                 img_mode : BGR
[2021/12/20 21:27:06] root INFO:             DetLabelEncode : None
[2021/12/20 21:27:06] root INFO:             DetResizeForTest : 
[2021/12/20 21:27:06] root INFO:                 image_shape : [736, 1280]
[2021/12/20 21:27:06] root INFO:             NormalizeImage : 
[2021/12/20 21:27:06] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/20 21:27:06] root INFO:                 order : hwc
[2021/12/20 21:27:06] root INFO:                 scale : 1./255.
[2021/12/20 21:27:06] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/20 21:27:06] root INFO:             ToCHWImage : None
[2021/12/20 21:27:06] root INFO:             KeepKeys : 
[2021/12/20 21:27:06] root INFO:                 keep_keys : ['image', 'shape', 'polys', 'ignore_tags']
[2021/12/20 21:27:06] root INFO:     loader : 
[2021/12/20 21:27:06] root INFO:         batch_size_per_card : 1
[2021/12/20 21:27:06] root INFO:         drop_last : False
[2021/12/20 21:27:06] root INFO:         num_workers : 8
[2021/12/20 21:27:06] root INFO:         shuffle : False
[2021/12/20 21:27:06] root INFO:         use_shared_memory : False
[2021/12/20 21:27:06] root INFO: Global : 
[2021/12/20 21:27:06] root INFO:     cal_metric_during_train : False
[2021/12/20 21:27:06] root INFO:     checkpoints : None
[2021/12/20 21:27:06] root INFO:     debug : False
[2021/12/20 21:27:06] root INFO:     distributed : False
[2021/12/20 21:27:06] root INFO:     epoch_num : 1200
[2021/12/20 21:27:06] root INFO:     eval_batch_step : [0, 2000]
[2021/12/20 21:27:06] root INFO:     infer_img : doc/imgs_en/img_10.jpg
[2021/12/20 21:27:06] root INFO:     log_smooth_window : 20
[2021/12/20 21:27:06] root INFO:     pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained
[2021/12/20 21:27:06] root INFO:     print_batch_step : 10
[2021/12/20 21:27:06] root INFO:     save_epoch_step : 1200
[2021/12/20 21:27:06] root INFO:     save_inference_dir : None
[2021/12/20 21:27:06] root INFO:     save_model_dir : ./output/db_mv3/
[2021/12/20 21:27:06] root INFO:     save_res_path : ./output/det_db/predicts_db.txt
[2021/12/20 21:27:06] root INFO:     use_gpu : True
[2021/12/20 21:27:06] root INFO:     use_visualdl : False
[2021/12/20 21:27:06] root INFO: Loss : 
[2021/12/20 21:27:06] root INFO:     alpha : 5
[2021/12/20 21:27:06] root INFO:     balance_loss : True
[2021/12/20 21:27:06] root INFO:     beta : 10
[2021/12/20 21:27:06] root INFO:     main_loss_type : DiceLoss
[2021/12/20 21:27:06] root INFO:     name : DBLoss
[2021/12/20 21:27:06] root INFO:     ohem_ratio : 3
[2021/12/20 21:27:06] root INFO: Metric : 
[2021/12/20 21:27:06] root INFO:     main_indicator : hmean
[2021/12/20 21:27:06] root INFO:     name : DetMetric
[2021/12/20 21:27:06] root INFO: Optimizer : 
[2021/12/20 21:27:06] root INFO:     beta1 : 0.9
[2021/12/20 21:27:06] root INFO:     beta2 : 0.999
[2021/12/20 21:27:06] root INFO:     lr : 
[2021/12/20 21:27:06] root INFO:         learning_rate : 0.001
[2021/12/20 21:27:06] root INFO:     name : Adam
[2021/12/20 21:27:06] root INFO:     regularizer : 
[2021/12/20 21:27:06] root INFO:         factor : 0
[2021/12/20 21:27:06] root INFO:         name : L2
[2021/12/20 21:27:06] root INFO: PostProcess : 
[2021/12/20 21:27:06] root INFO:     box_thresh : 0.6
[2021/12/20 21:27:06] root INFO:     max_candidates : 1000
[2021/12/20 21:27:06] root INFO:     name : DBPostProcess
[2021/12/20 21:27:06] root INFO:     thresh : 0.3
[2021/12/20 21:27:06] root INFO:     unclip_ratio : 1.5
[2021/12/20 21:27:06] root INFO: Train : 
[2021/12/20 21:27:06] root INFO:     dataset : 
[2021/12/20 21:27:06] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/20 21:27:06] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']
[2021/12/20 21:27:06] root INFO:         name : SimpleDataSet
[2021/12/20 21:27:06] root INFO:         ratio_list : [1.0]
[2021/12/20 21:27:06] root INFO:         transforms : 
[2021/12/20 21:27:06] root INFO:             DecodeImage : 
[2021/12/20 21:27:06] root INFO:                 channel_first : False
[2021/12/20 21:27:06] root INFO:                 img_mode : BGR
[2021/12/20 21:27:06] root INFO:             DetLabelEncode : None
[2021/12/20 21:27:06] root INFO:             IaaAugment : 
[2021/12/20 21:27:06] root INFO:                 augmenter_args : 
[2021/12/20 21:27:06] root INFO:                     args : 
[2021/12/20 21:27:06] root INFO:                         p : 0.5
[2021/12/20 21:27:06] root INFO:                     type : Fliplr
[2021/12/20 21:27:06] root INFO:                     args : 
[2021/12/20 21:27:06] root INFO:                         rotate : [-10, 10]
[2021/12/20 21:27:06] root INFO:                     type : Affine
[2021/12/20 21:27:06] root INFO:                     args : 
[2021/12/20 21:27:06] root INFO:                         size : [0.5, 3]
[2021/12/20 21:27:06] root INFO:                     type : Resize
[2021/12/20 21:27:06] root INFO:             EastRandomCropData : 
[2021/12/20 21:27:06] root INFO:                 keep_ratio : True
[2021/12/20 21:27:06] root INFO:                 max_tries : 50
[2021/12/20 21:27:06] root INFO:                 size : [640, 640]
[2021/12/20 21:27:06] root INFO:             MakeBorderMap : 
[2021/12/20 21:27:06] root INFO:                 shrink_ratio : 0.4
[2021/12/20 21:27:06] root INFO:                 thresh_max : 0.7
[2021/12/20 21:27:06] root INFO:                 thresh_min : 0.3
[2021/12/20 21:27:06] root INFO:             MakeShrinkMap : 
[2021/12/20 21:27:06] root INFO:                 min_text_size : 8
[2021/12/20 21:27:06] root INFO:                 shrink_ratio : 0.4
[2021/12/20 21:27:06] root INFO:             NormalizeImage : 
[2021/12/20 21:27:06] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/20 21:27:06] root INFO:                 order : hwc
[2021/12/20 21:27:06] root INFO:                 scale : 1./255.
[2021/12/20 21:27:06] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/20 21:27:06] root INFO:             ToCHWImage : None
[2021/12/20 21:27:06] root INFO:             KeepKeys : 
[2021/12/20 21:27:06] root INFO:                 keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
[2021/12/20 21:27:06] root INFO:     loader : 
[2021/12/20 21:27:06] root INFO:         batch_size_per_card : 16
[2021/12/20 21:27:06] root INFO:         drop_last : False
[2021/12/20 21:27:06] root INFO:         num_workers : 8
[2021/12/20 21:27:06] root INFO:         shuffle : True
[2021/12/20 21:27:06] root INFO:         use_shared_memory : False
[2021/12/20 21:27:06] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
[2021/12/20 21:27:06] root INFO: Initialize indexs of datasets:['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']
[2021/12/20 21:27:06] root INFO: Initialize indexs of datasets:['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']
W1220 21:27:06.898311  5756 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1220 21:27:06.902971  5756 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021/12/20 21:27:11] root INFO: The shape of model params neck.in2_conv.weight [256, 16, 1, 1] not matched with loaded params last_conv.weight [1280, 480, 1, 1] !
[2021/12/20 21:27:11] root INFO: The shape of model params neck.in3_conv.weight [256, 24, 1, 1] not matched with loaded params out.weight [1280, 1000] !
[2021/12/20 21:27:11] root INFO: The shape of model params neck.in4_conv.weight [256, 56, 1, 1] not matched with loaded params out.bias [1000] !
[2021/12/20 21:27:11] root INFO: loaded pretrained_model successful from ./pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams
[2021/12/20 21:27:11] root INFO: train dataloader has 63 iters
[2021/12/20 21:27:11] root INFO: valid dataloader has 500 iters
[2021/12/20 21:27:11] root INFO: During the training process, after the 0th iteration, an evaluation is run every 2000 iterations
[2021/12/20 21:27:11] root INFO: Initialize indexs of datasets:['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']
[2021/12/20 21:27:30] root INFO: epoch: [1/1200], iter: 10, lr: 0.001000, loss: 7.921230, loss_shrink_maps: 4.884024, loss_threshold_maps: 2.052891, loss_binary_maps: 0.978388, reader_cost: 1.09728 s, batch_cost: 1.82032 s, samples: 176, ips: 9.66863
[2021/12/20 21:27:37] root INFO: epoch: [1/1200], iter: 20, lr: 0.001000, loss: 6.997892, loss_shrink_maps: 4.848688, loss_threshold_maps: 1.204733, loss_binary_maps: 0.969754, reader_cost: 0.04498 s, batch_cost: 0.65158 s, samples: 160, ips: 24.55583
[2021/12/20 21:27:44] root INFO: epoch: [1/1200], iter: 30, lr: 0.001000, loss: 6.801436, loss_shrink_maps: 4.775781, loss_threshold_maps: 1.086550, loss_binary_maps: 0.949920, reader_cost: 0.06301 s, batch_cost: 0.69976 s, samples: 160, ips: 22.86486
[2021/12/20 21:27:51] root INFO: epoch: [1/1200], iter: 40, lr: 0.001000, loss: 6.507986, loss_shrink_maps: 4.605722, loss_threshold_maps: 1.032738, loss_binary_maps: 0.889316, reader_cost: 0.04577 s, batch_cost: 0.62410 s, samples: 160, ips: 25.63695
[2021/12/20 21:27:58] root INFO: epoch: [1/1200], iter: 50, lr: 0.001000, loss: 6.257969, loss_shrink_maps: 4.403828, loss_threshold_maps: 1.019095, loss_binary_maps: 0.798222, reader_cost: 0.00828 s, batch_cost: 0.65507 s, samples: 160, ips: 24.42477
^C
main proc 5778 exit, kill process group 5756
main proc 5779 exit, kill process group 5756
main proc 5776 exit, kill process group 5756
main proc 5775 exit, kill process group 5756
main proc 5777 exit, kill process group 5756
main proc 5774 exit, kill process group 5756
main proc 5773 exit, kill process group 5756
main proc 5772 exit, kill process group 5756
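
The iteration counts in the log above can be checked with a little arithmetic. The sketch assumes 1000 training images in ICDAR2015 (a property of the dataset, not something stated in the log) and single-card training; the batch size, epoch count, and evaluation interval are taken from the printed configuration.

```python
import math

num_train_images = 1000      # assumption: size of the ICDAR2015 training set
batch_size_per_card = 16     # Train.loader.batch_size_per_card
epoch_num = 1200             # Global.epoch_num
eval_interval = 2000         # second entry of Global.eval_batch_step

# drop_last is False, so the final partial batch still counts as an iteration
iters_per_epoch = math.ceil(num_train_images / batch_size_per_card)
total_iters = iters_per_epoch * epoch_num
epochs_between_evals = eval_interval / iters_per_epoch

print(iters_per_epoch)                 # compare with "train dataloader has 63 iters"
print(total_iters)
print(round(epochs_between_evals, 1))
```

Under these assumptions an evaluation every 2000 iterations corresponds to roughly one evaluation every 32 epochs, which is worth keeping in mind when choosing eval_batch_step for a short run.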

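After training, the model is saved by default under PaddleOCR/output/db_mv3/. To use a different directory, set Global.save_model_dir when launching training, for example:

# Override Global.save_model_dir from the config file to change where the model is saved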
python tools/train.py -c configs/det/det_mv3_db.yml -o Global.save_model_dir="./output/save_db_train/"
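
To illustrate what a `-o Section.key=value` override does conceptually, here is a minimal sketch that applies a dotted-key override to a nested configuration dictionary. The function `apply_override` and the sample dictionary are hypothetical and are not PaddleOCR's own implementation, which may parse and type-convert values differently; in this sketch every overridden value stays a string.

```python
def apply_override(config, override):
    """Apply one 'A.b.c=value' override to a nested dict, in place."""
    dotted_key, value = override.split("=", 1)
    keys = dotted_key.split(".")
    node = config
    for key in keys[:-1]:
        # Create intermediate sections on demand
        node = node.setdefault(key, {})
    # Strip optional surrounding double quotes from the value
    node[keys[-1]] = value.strip('"')
    return config


config = {"Global": {"save_model_dir": "./output/db_mv3/", "epoch_num": 1200}}
apply_override(config, 'Global.save_model_dir="./output/save_db_train/"')
print(config["Global"]["save_model_dir"])
```

Only the named key changes; every other entry loaded from the configuration file keeps its original value.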

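3.8 Model Evaluation

During training, two kinds of model are saved by default: the most recently trained model, named latest, and the most accurate model, named best_accuracy. The saved parameters are used next to evaluate precision, recall, and hmean on the test set:

The text detection evaluation code is in PaddleOCR/ppocr/metrics/det_metric.py; running tools/eval.py evaluates the accuracy of a trained model.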

!python tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./output/db_mv3/best_accuracy

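3.9 Model Prediction

After training, a saved model can also run inference on a single image or on a folder of images, to inspect prediction quality. The command below loads the checkpoint at ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy and runs it on ./doc/imgs_en/img_12.jpg.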

!python tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy Global.infer_img=./doc/imgs_en/img_12.jpg
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)
[2021/12/22 15:04:32] root INFO: Architecture : 
[2021/12/22 15:04:32] root INFO:     Backbone : 
[2021/12/22 15:04:32] root INFO:         model_name : large
[2021/12/22 15:04:32] root INFO:         name : MobileNetV3
[2021/12/22 15:04:32] root INFO:         scale : 0.5
[2021/12/22 15:04:32] root INFO:     Head : 
[2021/12/22 15:04:32] root INFO:         k : 50
[2021/12/22 15:04:32] root INFO:         name : DBHead
[2021/12/22 15:04:32] root INFO:     Neck : 
[2021/12/22 15:04:32] root INFO:         name : DBFPN
[2021/12/22 15:04:32] root INFO:         out_channels : 256
[2021/12/22 15:04:32] root INFO:     Transform : None
[2021/12/22 15:04:32] root INFO:     algorithm : DB
[2021/12/22 15:04:32] root INFO:     model_type : det
[2021/12/22 15:04:32] root INFO: Eval : 
[2021/12/22 15:04:32] root INFO:     dataset : 
[2021/12/22 15:04:32] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/22 15:04:32] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']
[2021/12/22 15:04:32] root INFO:         name : SimpleDataSet
[2021/12/22 15:04:32] root INFO:         transforms : 
[2021/12/22 15:04:32] root INFO:             DecodeImage : 
[2021/12/22 15:04:32] root INFO:                 channel_first : False
[2021/12/22 15:04:32] root INFO:                 img_mode : BGR
[2021/12/22 15:04:32] root INFO:             DetLabelEncode : None
[2021/12/22 15:04:32] root INFO:             DetResizeForTest : 
[2021/12/22 15:04:32] root INFO:                 image_shape : [736, 1280]
[2021/12/22 15:04:32] root INFO:             NormalizeImage : 
[2021/12/22 15:04:32] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/22 15:04:32] root INFO:                 order : hwc
[2021/12/22 15:04:32] root INFO:                 scale : 1./255.
[2021/12/22 15:04:32] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/22 15:04:32] root INFO:             ToCHWImage : None
[2021/12/22 15:04:32] root INFO:             KeepKeys : 
[2021/12/22 15:04:32] root INFO:                 keep_keys : ['image', 'shape', 'polys', 'ignore_tags']
[2021/12/22 15:04:32] root INFO:     loader : 
[2021/12/22 15:04:32] root INFO:         batch_size_per_card : 1
[2021/12/22 15:04:32] root INFO:         drop_last : False
[2021/12/22 15:04:32] root INFO:         num_workers : 8
[2021/12/22 15:04:32] root INFO:         shuffle : False
[2021/12/22 15:04:32] root INFO:         use_shared_memory : False
[2021/12/22 15:04:32] root INFO: Global : 
[2021/12/22 15:04:32] root INFO:     cal_metric_during_train : False
[2021/12/22 15:04:32] root INFO:     checkpoints : ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
[2021/12/22 15:04:32] root INFO:     debug : False
[2021/12/22 15:04:32] root INFO:     distributed : False
[2021/12/22 15:04:32] root INFO:     epoch_num : 1200
[2021/12/22 15:04:32] root INFO:     eval_batch_step : [0, 2000]
[2021/12/22 15:04:32] root INFO:     infer_img : ./doc/imgs_en/img_12.jpg
[2021/12/22 15:04:32] root INFO:     log_smooth_window : 20
[2021/12/22 15:04:32] root INFO:     pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained
[2021/12/22 15:04:32] root INFO:     print_batch_step : 10
[2021/12/22 15:04:32] root INFO:     save_epoch_step : 1200
[2021/12/22 15:04:32] root INFO:     save_inference_dir : None
[2021/12/22 15:04:32] root INFO:     save_model_dir : ./output/db_mv3/
[2021/12/22 15:04:32] root INFO:     save_res_path : ./output/det_db/predicts_db.txt
[2021/12/22 15:04:32] root INFO:     use_gpu : True
[2021/12/22 15:04:32] root INFO:     use_visualdl : False
[2021/12/22 15:04:32] root INFO: Loss : 
[2021/12/22 15:04:32] root INFO:     alpha : 5
[2021/12/22 15:04:32] root INFO:     balance_loss : True
[2021/12/22 15:04:32] root INFO:     beta : 10
[2021/12/22 15:04:32] root INFO:     main_loss_type : DiceLoss
[2021/12/22 15:04:32] root INFO:     name : DBLoss
[2021/12/22 15:04:32] root INFO:     ohem_ratio : 3
[2021/12/22 15:04:32] root INFO: Metric : 
[2021/12/22 15:04:32] root INFO:     main_indicator : hmean
[2021/12/22 15:04:32] root INFO:     name : DetMetric
[2021/12/22 15:04:32] root INFO: Optimizer : 
[2021/12/22 15:04:32] root INFO:     beta1 : 0.9
[2021/12/22 15:04:32] root INFO:     beta2 : 0.999
[2021/12/22 15:04:32] root INFO:     lr : 
[2021/12/22 15:04:32] root INFO:         learning_rate : 0.001
[2021/12/22 15:04:32] root INFO:     name : Adam
[2021/12/22 15:04:32] root INFO:     regularizer : 
[2021/12/22 15:04:32] root INFO:         factor : 0
[2021/12/22 15:04:32] root INFO:         name : L2
[2021/12/22 15:04:32] root INFO: PostProcess : 
[2021/12/22 15:04:32] root INFO:     box_thresh : 0.6
[2021/12/22 15:04:32] root INFO:     max_candidates : 1000
[2021/12/22 15:04:32] root INFO:     name : DBPostProcess
[2021/12/22 15:04:32] root INFO:     thresh : 0.3
[2021/12/22 15:04:32] root INFO:     unclip_ratio : 1.5
[2021/12/22 15:04:32] root INFO: Train : 
[2021/12/22 15:04:32] root INFO:     dataset : 
[2021/12/22 15:04:32] root INFO:         data_dir : ./train_data/icdar2015/text_localization/
[2021/12/22 15:04:32] root INFO:         label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']
[2021/12/22 15:04:32] root INFO:         name : SimpleDataSet
[2021/12/22 15:04:32] root INFO:         ratio_list : [1.0]
[2021/12/22 15:04:32] root INFO:         transforms : 
[2021/12/22 15:04:32] root INFO:             DecodeImage : 
[2021/12/22 15:04:32] root INFO:                 channel_first : False
[2021/12/22 15:04:32] root INFO:                 img_mode : BGR
[2021/12/22 15:04:32] root INFO:             DetLabelEncode : None
[2021/12/22 15:04:32] root INFO:             IaaAugment : 
[2021/12/22 15:04:32] root INFO:                 augmenter_args : 
[2021/12/22 15:04:32] root INFO:                     args : 
[2021/12/22 15:04:32] root INFO:                         p : 0.5
[2021/12/22 15:04:32] root INFO:                     type : Fliplr
[2021/12/22 15:04:32] root INFO:                     args : 
[2021/12/22 15:04:32] root INFO:                         rotate : [-10, 10]
[2021/12/22 15:04:32] root INFO:                     type : Affine
[2021/12/22 15:04:32] root INFO:                     args : 
[2021/12/22 15:04:32] root INFO:                         size : [0.5, 3]
[2021/12/22 15:04:32] root INFO:                     type : Resize
[2021/12/22 15:04:32] root INFO:             EastRandomCropData : 
[2021/12/22 15:04:32] root INFO:                 keep_ratio : True
[2021/12/22 15:04:32] root INFO:                 max_tries : 50
[2021/12/22 15:04:32] root INFO:                 size : [640, 640]
[2021/12/22 15:04:32] root INFO:             MakeBorderMap : 
[2021/12/22 15:04:32] root INFO:                 shrink_ratio : 0.4
[2021/12/22 15:04:32] root INFO:                 thresh_max : 0.7
[2021/12/22 15:04:32] root INFO:                 thresh_min : 0.3
[2021/12/22 15:04:32] root INFO:             MakeShrinkMap : 
[2021/12/22 15:04:32] root INFO:                 min_text_size : 8
[2021/12/22 15:04:32] root INFO:                 shrink_ratio : 0.4
[2021/12/22 15:04:32] root INFO:             NormalizeImage : 
[2021/12/22 15:04:32] root INFO:                 mean : [0.485, 0.456, 0.406]
[2021/12/22 15:04:32] root INFO:                 order : hwc
[2021/12/22 15:04:32] root INFO:                 scale : 1./255.
[2021/12/22 15:04:32] root INFO:                 std : [0.229, 0.224, 0.225]
[2021/12/22 15:04:32] root INFO:             ToCHWImage : None
[2021/12/22 15:04:32] root INFO:             KeepKeys : 
[2021/12/22 15:04:32] root INFO:                 keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
[2021/12/22 15:04:32] root INFO:     loader : 
[2021/12/22 15:04:32] root INFO:         batch_size_per_card : 16
[2021/12/22 15:04:32] root INFO:         drop_last : False
[2021/12/22 15:04:32] root INFO:         num_workers : 8
[2021/12/22 15:04:32] root INFO:         shuffle : True
[2021/12/22 15:04:32] root INFO:         use_shared_memory : False
[2021/12/22 15:04:32] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
W1222 15:04:32.031893  1854 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W1222 15:04:32.036085  1854 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021/12/22 15:04:35] root INFO: resume from ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
[2021/12/22 15:04:35] root INFO: infer_img: ./doc/imgs_en/img_12.jpg
[2021/12/22 15:04:35] root INFO: The detected Image saved in ./output/det_db/det_results/img_12.jpg
[2021/12/22 15:04:35] root INFO: success!

By default, the predicted image is saved in the ./output/det_db/det_results/ directory. It can be loaded with PIL and displayed with matplotlib as follows:

import matplotlib.pyplot as plt
# In a notebook, this magic command is needed so that matplotlib.pyplot figures are displayed inline
%matplotlib inline
from PIL import Image
import numpy as np

img = Image.open('./output/det_db/det_results/img_12.jpg')
img = np.array(img)

# Display the loaded image
plt.figure(figsize=(20, 20))
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f20c93d7050>

[Figure: visualized detection result for img_12.jpg]

4. Summary

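This section showed how to get started quickly with PaddleOCR's text detection models and, using the DB algorithm as an example, walked through the implementation from data processing to training a text detection model. The next section covers text recognition algorithms.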

FAQ

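  1. How should I handle cases where part of the text is missed by the detector, as in the figure below?

[Figure: example of partially missed text]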

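In the case above, only part of the text is detected, yet the IoU between the predicted box and the ground-truth (GT) box is still greater than the 0.5 threshold. The prediction therefore counts as a correct match, and the evaluation metrics do not reflect the problem. If many results look like this, consider increasing the IoU threshold. More fundamentally, text is missed because the features of that part of the text produce no response; in other words, the network has not learned the features of the missed text. Analyze each case individually: visualize the predictions to find out why text is missed (for example lighting, deformation, or very long text lines), and then improve the results with targeted data augmentation, changes to the network, or adjustments to the post-processing.
For more text detection FAQs, see the next section.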
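
The IoU effect described in the FAQ answer can be illustrated with a small, self-contained sketch. It is a simplification, not the PaddleOCR evaluation code: it assumes axis-aligned rectangles given as (x1, y1, x2, y2), whereas detection boxes are in general polygons, and the box coordinates and the helper names box_area and box_iou are hypothetical values and names chosen for illustration.

```python
def box_area(box):
    """Area of an axis-aligned box (x1, y1, x2, y2); 0 if the box is empty."""
    x1, y1, x2, y2 = box
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def box_iou(a, b):
    """Intersection over union of two axis-aligned boxes."""
    # The intersection rectangle; it is empty when the boxes do not overlap.
    inter_box = (max(a[0], b[0]), max(a[1], b[1]),
                 min(a[2], b[2]), min(a[3], b[3]))
    inter = box_area(inter_box)
    union = box_area(a) + box_area(b) - inter
    return inter / union if union > 0 else 0.0


# Hypothetical example: the GT box spans a whole text line,
# but the prediction covers only the first 60% of it.
gt_box = (0, 0, 100, 20)
pred_box = (0, 0, 60, 20)

iou = box_iou(gt_box, pred_box)
print(f"IoU = {iou:.2f}")

# Compare the default threshold from the text (0.5) with a stricter,
# illustrative one (0.7).
for threshold in (0.5, 0.7):
    status = "match" if iou > threshold else "miss"
    print(f"IoU threshold {threshold}: counted as a {status}")
```

With the 0.5 threshold, the prediction is counted as a match even though 40% of the text line is missing; under the stricter threshold it is counted as a miss, which is why raising the threshold makes this kind of partial detection visible in the metrics.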

Homework

Short-answer question:

  1. Based on the sizes of the feature maps output by the DB backbone and FPN, the height and width of the DB input image must be multiples of _?
    A: 32, B: 64

Hands-on exercise:

  1. Using the DB configuration file configs/det/det_mv3_db.yml, train a text detection model on the dataset det_data_lesson_demo.tar and tune it to improve accuracy.
