Text Recognition in Practice
The previous chapter introduced the main approaches to text recognition. Among them, CRNN was one of the earliest methods proposed and is still widely used in industry. This chapter explains in detail how to build, train, evaluate and run inference with a CRNN text recognition model using PaddleOCR. The dataset is ICDAR 2015, with 4468 training images and 2077 test images.
After studying this chapter you will be able to:
- run text recognition inference quickly with the paddleocr whl package
- explain the basic principle and network structure of CRNN
- carry out the required steps of model training and tune its hyperparameters
- train the network on a custom dataset
1. Quick Start
1.1 Install dependencies and the whl package
First make sure paddle and paddleocr are installed; if they already are, skip this step.
# Install the GPU version of PaddlePaddle
!pip install paddlepaddle-gpu
# Install the paddleocr whl package
! pip install -U pip
! pip install paddleocr
Successfully installed PyWavelets-1.2.0 cssselect-1.1.0 cssutils-2.3.0 fasttext-0.9.1 imgaug-0.4.0 lmdb-1.2.1 lxml-4.7.1 opencv-contrib-python-4.4.0.46 paddleocr-2.3.0.2 premailer-3.10.0 pybind11-2.8.1 pyclipper-1.3.0.post2 python-Levenshtein-0.12.2 scikit-image-0.19.1 shapely-1.8.0 tifffile-2021.11.2
1.2 Quick text recognition
The paddleocr whl package automatically downloads the lightweight PP-OCR models as its defaults.
The following shows how to run recognition with the whl package:
Test image:
from paddleocr import PaddleOCR
ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = '/home/aistudio/work/word_19.png'
result = ocr.ocr(img_path, det=False)
for line in result:
print(line)
('SLOW', 0.8881992)
Running the code above returns the recognized text and its confidence:
('SLOW', 0.8881992)
You now know how to run recognition with the paddleocr whl package. There are more test images under ./work/, so feel free to try other images.
2. How Prediction Works
In Section 1, paddleocr loaded a trained CRNN recognition model to make predictions. This section explains the principle and workflow of CRNN in detail.
2.1 Where CRNN fits
CRNN is a CTC-based algorithm and occupies the corresponding branch of the taxonomy presented in the theory chapter. CRNN mainly targets regular (roughly horizontal) text; CTC-based algorithms offer fast inference and handle long text well, which is why CRNN is the Chinese recognition algorithm adopted by PP-OCR.
2.2 Algorithm details
The CRNN architecture is shown below; from bottom to top it consists of three parts: convolutional layers, recurrent layers and a transcription layer:
1) backbone:
A convolutional network serves as the low-level backbone and extracts a feature sequence from the input image. Because convolution, max-pooling, element-wise operations and activation functions all act on local regions, they are translation invariant. Each column of the feature map therefore corresponds to a rectangular region of the original image (its receptive field), and these rectangular regions keep the same left-to-right order as their corresponding columns in the feature map. Since a CNN must scale its input images to a fixed size to satisfy its fixed input dimensions, it is not well suited to sequence-like objects whose length varies greatly. To better support variable-length sequences, CRNN feeds the feature maps produced by the last backbone layer into an RNN, converting them into sequence features.
2) neck:
The recurrent layers build a recurrent network on top of the convolutional one, converting image features into sequence features and predicting a label distribution for each frame. RNNs are very good at capturing contextual information within a sequence, and using contextual cues for image-based sequence recognition is more effective than processing each frame in isolation. In scene text recognition, for example, a wide character may need several consecutive frames to be fully described, and some ambiguous characters are easier to distinguish when their context is visible. Second, an RNN can back-propagate the error differentials to the convolutional layers, so the whole network can be trained jointly. Third, an RNN can operate on sequences of arbitrary length, which handles the varying width of text images. CRNN uses a two-layer LSTM as its recurrent part, which alleviates the vanishing and exploding gradients that occur when training on long sequences.
3) head:
The transcription layer converts the per-frame predictions into the final label sequence through a fully connected network and a softmax activation. CTC loss is then used to train the CNN and RNN jointly without requiring any sequence alignment. CTC comes with a special sequence-collapsing mechanism: after the LSTM outputs a sequence, each time step is classified, and since several consecutive time steps may predict the same class, identical consecutive predictions are merged. To avoid also merging characters that genuinely repeat in the text, CTC inserts a blank symbol between repeated characters.
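To make this collapsing rule concrete, here is a tiny illustration (a minimal sketch, not PaddleOCR code; '-' stands for the blank symbol):
def ctc_collapse(frames, blank='-'):
    """Greedy CTC decoding: merge consecutive duplicates, then drop blanks."""
    out = []
    prev = None
    for ch in frames:
        if ch != prev and ch != blank:
            out.append(ch)
        prev = ch
    return ''.join(out)
# "hello" needs a blank between the two l's so that they are not merged into one
print(ctc_collapse(list("hh-e-ll-l-oo")))   # -> hello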
2.3 Code walk-through
The overall network structure is very compact and the code is correspondingly simple; the modules can be built one by one following the prediction pipeline. This section covers: data input, building the backbone, building the neck and building the head.
【Data input】
Before entering the network, images must be resized to the uniform shape (3, 32, 320) and normalized. The data augmentation used during training is omitted here; the required preprocessing steps are demonstrated on a single image (source code location):
import cv2
import math
import numpy as np
def resize_norm_img(img):
    """
    Resize and normalize an image
    :param img: input image
    """
    # default input shape
    imgC = 3
    imgH = 32
    imgW = 320
    # actual height and width of the image
    h, w = img.shape[:2]
    # actual aspect ratio
    ratio = w / float(h)
    # resize keeping the aspect ratio
    if math.ceil(imgH * ratio) > imgW:
        # wider than the default width: clip to imgW
        resized_w = imgW
    else:
        # otherwise use the scaled width
        resized_w = int(math.ceil(imgH * ratio))
    # resize
    resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    # normalize
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    # zero-pad the missing width
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    # transpose the padded image back to HWC for visualization
    draw_img = padding_im.transpose((1,2,0))
    return padding_im, draw_img
import matplotlib.pyplot as plt
# read the image
raw_img = cv2.imread("/home/aistudio/work/word_1.png")
plt.figure()
plt.subplot(2,1,1)
# show the original image
plt.imshow(raw_img)
# resize and normalize
padding_im, draw_img = resize_norm_img(raw_img)
plt.subplot(2,1,2)
# show the image actually fed to the network
plt.imshow(draw_img)
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
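This warning is expected: after normalization the pixel values lie in [-1, 1], while imshow expects floats in [0, 1], so it clips them for display; the network input itself is unaffected.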
【Network structure】
- backbone
PaddleOCR uses MobileNetV3 as the backbone, assembled in the same order as the network structure. First define the common building blocks (source code location): ConvBNLayer, ResidualUnit, make_divisible.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None):
"""
卷积BN层
:param in_channels: 输入通道数
:param out_channels: 输出通道数
:param kernel_size: 卷积核尺寸
:parma stride: 步长大小
:param padding: 填充大小
:param groups: 二维卷积层的组数
:param if_act: 是否添加激活函数
:param act: 激活函数
"""
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=False)
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
def forward(self, x):
        # conv layer
        x = self.conv(x)
        # batch-norm layer
        x = self.bn(x)
        # optionally apply the activation function
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = F.hardswish(x)
else:
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
return x
class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4):
"""
SE模块
:param in_channels: 输入通道数
:param reduction: 通道缩放率
"""
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0)
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, inputs):
        # global average pooling
        outputs = self.avg_pool(inputs)
        # first 1x1 convolution
        outputs = self.conv1(outputs)
        # relu activation
        outputs = F.relu(outputs)
        # second 1x1 convolution
        outputs = self.conv2(outputs)
        # hardsigmoid activation
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs
class ResidualUnit(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
kernel_size,
stride,
use_se,
act=None):
"""
残差层
:param in_channels: 输入通道数
:param mid_channels: 中间通道数
:param out_channels: 输出通道数
:param kernel_size: 卷积核尺寸
:parma stride: 步长大小
:param use_se: 是否使用se模块
:param act: 激活函数
"""
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act)
if self.if_se:
self.mid_se = SEModule(mid_channels)
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None)
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
return x
def make_divisible(v, divisor=8, min_value=None):
"""
确保被8整除
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
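For example, with the default scale of 0.5 the channel numbers in the configuration table are halved and then rounded to a multiple of 8:
print(make_divisible(16 * 0.5))    # 8
print(make_divisible(576 * 0.5))   # 288, the backbone's final channel count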
Now build the backbone network from these common modules:
class MobileNetV3(nn.Layer):
def __init__(self,
in_channels=3,
model_name='small',
scale=0.5,
small_stride=None,
disable_se=False,
**kwargs):
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
small_stride = [1, 2, 2, 2]
if model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert scale in supported_scale, \
"supported scales are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act='hardswish')
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl))
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act='hardswish')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
def forward(self, x):
x = self.conv1(x)
x = self.blocks(x)
x = self.conv2(x)
x = self.pool(x)
return x
This completes the backbone definition. The whole network structure can be visualized with paddle.summary:
# define the network input shape
IMAGE_SHAPE_C = 3
IMAGE_SHAPE_H = 32
IMAGE_SHAPE_W = 320
# visualize the network structure
paddle.summary(MobileNetV3(),[(1, IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)])
-------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===============================================================================
Conv2D-1 [[1, 3, 32, 320]] [1, 8, 16, 160] 216
BatchNorm-1 [[1, 8, 16, 160]] [1, 8, 16, 160] 32
ConvBNLayer-1 [[1, 3, 32, 320]] [1, 8, 16, 160] 0
Conv2D-2 [[1, 8, 16, 160]] [1, 8, 16, 160] 64
BatchNorm-2 [[1, 8, 16, 160]] [1, 8, 16, 160] 32
ConvBNLayer-2 [[1, 8, 16, 160]] [1, 8, 16, 160] 0
Conv2D-3 [[1, 8, 16, 160]] [1, 8, 16, 160] 72
BatchNorm-3 [[1, 8, 16, 160]] [1, 8, 16, 160] 32
ConvBNLayer-3 [[1, 8, 16, 160]] [1, 8, 16, 160] 0
AdaptiveAvgPool2D-1 [[1, 8, 16, 160]] [1, 8, 1, 1] 0
Conv2D-4 [[1, 8, 1, 1]] [1, 2, 1, 1] 18
Conv2D-5 [[1, 2, 1, 1]] [1, 8, 1, 1] 24
SEModule-1 [[1, 8, 16, 160]] [1, 8, 16, 160] 0
Conv2D-6 [[1, 8, 16, 160]] [1, 8, 16, 160] 64
BatchNorm-4 [[1, 8, 16, 160]] [1, 8, 16, 160] 32
ConvBNLayer-4 [[1, 8, 16, 160]] [1, 8, 16, 160] 0
ResidualUnit-1 [[1, 8, 16, 160]] [1, 8, 16, 160] 0
Conv2D-7 [[1, 8, 16, 160]] [1, 40, 16, 160] 320
BatchNorm-5 [[1, 40, 16, 160]] [1, 40, 16, 160] 160
ConvBNLayer-5 [[1, 8, 16, 160]] [1, 40, 16, 160] 0
Conv2D-8 [[1, 40, 16, 160]] [1, 40, 8, 160] 360
BatchNorm-6 [[1, 40, 8, 160]] [1, 40, 8, 160] 160
ConvBNLayer-6 [[1, 40, 16, 160]] [1, 40, 8, 160] 0
Conv2D-9 [[1, 40, 8, 160]] [1, 16, 8, 160] 640
BatchNorm-7 [[1, 16, 8, 160]] [1, 16, 8, 160] 64
ConvBNLayer-7 [[1, 40, 8, 160]] [1, 16, 8, 160] 0
ResidualUnit-2 [[1, 8, 16, 160]] [1, 16, 8, 160] 0
Conv2D-10 [[1, 16, 8, 160]] [1, 48, 8, 160] 768
BatchNorm-8 [[1, 48, 8, 160]] [1, 48, 8, 160] 192
ConvBNLayer-8 [[1, 16, 8, 160]] [1, 48, 8, 160] 0
Conv2D-11 [[1, 48, 8, 160]] [1, 48, 8, 160] 432
BatchNorm-9 [[1, 48, 8, 160]] [1, 48, 8, 160] 192
ConvBNLayer-9 [[1, 48, 8, 160]] [1, 48, 8, 160] 0
Conv2D-12 [[1, 48, 8, 160]] [1, 16, 8, 160] 768
BatchNorm-10 [[1, 16, 8, 160]] [1, 16, 8, 160] 64
ConvBNLayer-10 [[1, 48, 8, 160]] [1, 16, 8, 160] 0
ResidualUnit-3 [[1, 16, 8, 160]] [1, 16, 8, 160] 0
Conv2D-13 [[1, 16, 8, 160]] [1, 48, 8, 160] 768
BatchNorm-11 [[1, 48, 8, 160]] [1, 48, 8, 160] 192
ConvBNLayer-11 [[1, 16, 8, 160]] [1, 48, 8, 160] 0
Conv2D-14 [[1, 48, 8, 160]] [1, 48, 4, 160] 1,200
BatchNorm-12 [[1, 48, 4, 160]] [1, 48, 4, 160] 192
ConvBNLayer-12 [[1, 48, 8, 160]] [1, 48, 4, 160] 0
AdaptiveAvgPool2D-2 [[1, 48, 4, 160]] [1, 48, 1, 1] 0
Conv2D-15 [[1, 48, 1, 1]] [1, 12, 1, 1] 588
Conv2D-16 [[1, 12, 1, 1]] [1, 48, 1, 1] 624
SEModule-2 [[1, 48, 4, 160]] [1, 48, 4, 160] 0
Conv2D-17 [[1, 48, 4, 160]] [1, 24, 4, 160] 1,152
BatchNorm-13 [[1, 24, 4, 160]] [1, 24, 4, 160] 96
ConvBNLayer-13 [[1, 48, 4, 160]] [1, 24, 4, 160] 0
ResidualUnit-4 [[1, 16, 8, 160]] [1, 24, 4, 160] 0
Conv2D-18 [[1, 24, 4, 160]] [1, 120, 4, 160] 2,880
BatchNorm-14 [[1, 120, 4, 160]] [1, 120, 4, 160] 480
ConvBNLayer-14 [[1, 24, 4, 160]] [1, 120, 4, 160] 0
Conv2D-19 [[1, 120, 4, 160]] [1, 120, 4, 160] 3,000
BatchNorm-15 [[1, 120, 4, 160]] [1, 120, 4, 160] 480
ConvBNLayer-15 [[1, 120, 4, 160]] [1, 120, 4, 160] 0
AdaptiveAvgPool2D-3 [[1, 120, 4, 160]] [1, 120, 1, 1] 0
Conv2D-20 [[1, 120, 1, 1]] [1, 30, 1, 1] 3,630
Conv2D-21 [[1, 30, 1, 1]] [1, 120, 1, 1] 3,720
SEModule-3 [[1, 120, 4, 160]] [1, 120, 4, 160] 0
Conv2D-22 [[1, 120, 4, 160]] [1, 24, 4, 160] 2,880
BatchNorm-16 [[1, 24, 4, 160]] [1, 24, 4, 160] 96
ConvBNLayer-16 [[1, 120, 4, 160]] [1, 24, 4, 160] 0
ResidualUnit-5 [[1, 24, 4, 160]] [1, 24, 4, 160] 0
Conv2D-23 [[1, 24, 4, 160]] [1, 120, 4, 160] 2,880
BatchNorm-17 [[1, 120, 4, 160]] [1, 120, 4, 160] 480
ConvBNLayer-17 [[1, 24, 4, 160]] [1, 120, 4, 160] 0
Conv2D-24 [[1, 120, 4, 160]] [1, 120, 4, 160] 3,000
BatchNorm-18 [[1, 120, 4, 160]] [1, 120, 4, 160] 480
ConvBNLayer-18 [[1, 120, 4, 160]] [1, 120, 4, 160] 0
AdaptiveAvgPool2D-4 [[1, 120, 4, 160]] [1, 120, 1, 1] 0
Conv2D-25 [[1, 120, 1, 1]] [1, 30, 1, 1] 3,630
Conv2D-26 [[1, 30, 1, 1]] [1, 120, 1, 1] 3,720
SEModule-4 [[1, 120, 4, 160]] [1, 120, 4, 160] 0
Conv2D-27 [[1, 120, 4, 160]] [1, 24, 4, 160] 2,880
BatchNorm-19 [[1, 24, 4, 160]] [1, 24, 4, 160] 96
ConvBNLayer-19 [[1, 120, 4, 160]] [1, 24, 4, 160] 0
ResidualUnit-6 [[1, 24, 4, 160]] [1, 24, 4, 160] 0
Conv2D-28 [[1, 24, 4, 160]] [1, 64, 4, 160] 1,536
BatchNorm-20 [[1, 64, 4, 160]] [1, 64, 4, 160] 256
ConvBNLayer-20 [[1, 24, 4, 160]] [1, 64, 4, 160] 0
Conv2D-29 [[1, 64, 4, 160]] [1, 64, 4, 160] 1,600
BatchNorm-21 [[1, 64, 4, 160]] [1, 64, 4, 160] 256
ConvBNLayer-21 [[1, 64, 4, 160]] [1, 64, 4, 160] 0
AdaptiveAvgPool2D-5 [[1, 64, 4, 160]] [1, 64, 1, 1] 0
Conv2D-30 [[1, 64, 1, 1]] [1, 16, 1, 1] 1,040
Conv2D-31 [[1, 16, 1, 1]] [1, 64, 1, 1] 1,088
SEModule-5 [[1, 64, 4, 160]] [1, 64, 4, 160] 0
Conv2D-32 [[1, 64, 4, 160]] [1, 24, 4, 160] 1,536
BatchNorm-22 [[1, 24, 4, 160]] [1, 24, 4, 160] 96
ConvBNLayer-22 [[1, 64, 4, 160]] [1, 24, 4, 160] 0
ResidualUnit-7 [[1, 24, 4, 160]] [1, 24, 4, 160] 0
Conv2D-33 [[1, 24, 4, 160]] [1, 72, 4, 160] 1,728
BatchNorm-23 [[1, 72, 4, 160]] [1, 72, 4, 160] 288
ConvBNLayer-23 [[1, 24, 4, 160]] [1, 72, 4, 160] 0
Conv2D-34 [[1, 72, 4, 160]] [1, 72, 4, 160] 1,800
BatchNorm-24 [[1, 72, 4, 160]] [1, 72, 4, 160] 288
ConvBNLayer-24 [[1, 72, 4, 160]] [1, 72, 4, 160] 0
AdaptiveAvgPool2D-6 [[1, 72, 4, 160]] [1, 72, 1, 1] 0
Conv2D-35 [[1, 72, 1, 1]] [1, 18, 1, 1] 1,314
Conv2D-36 [[1, 18, 1, 1]] [1, 72, 1, 1] 1,368
SEModule-6 [[1, 72, 4, 160]] [1, 72, 4, 160] 0
Conv2D-37 [[1, 72, 4, 160]] [1, 24, 4, 160] 1,728
BatchNorm-25 [[1, 24, 4, 160]] [1, 24, 4, 160] 96
ConvBNLayer-25 [[1, 72, 4, 160]] [1, 24, 4, 160] 0
ResidualUnit-8 [[1, 24, 4, 160]] [1, 24, 4, 160] 0
Conv2D-38 [[1, 24, 4, 160]] [1, 144, 4, 160] 3,456
BatchNorm-26 [[1, 144, 4, 160]] [1, 144, 4, 160] 576
ConvBNLayer-26 [[1, 24, 4, 160]] [1, 144, 4, 160] 0
Conv2D-39 [[1, 144, 4, 160]] [1, 144, 2, 160] 3,600
BatchNorm-27 [[1, 144, 2, 160]] [1, 144, 2, 160] 576
ConvBNLayer-27 [[1, 144, 4, 160]] [1, 144, 2, 160] 0
AdaptiveAvgPool2D-7 [[1, 144, 2, 160]] [1, 144, 1, 1] 0
Conv2D-40 [[1, 144, 1, 1]] [1, 36, 1, 1] 5,220
Conv2D-41 [[1, 36, 1, 1]] [1, 144, 1, 1] 5,328
SEModule-7 [[1, 144, 2, 160]] [1, 144, 2, 160] 0
Conv2D-42 [[1, 144, 2, 160]] [1, 48, 2, 160] 6,912
BatchNorm-28 [[1, 48, 2, 160]] [1, 48, 2, 160] 192
ConvBNLayer-28 [[1, 144, 2, 160]] [1, 48, 2, 160] 0
ResidualUnit-9 [[1, 24, 4, 160]] [1, 48, 2, 160] 0
Conv2D-43 [[1, 48, 2, 160]] [1, 288, 2, 160] 13,824
BatchNorm-29 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152
ConvBNLayer-29 [[1, 48, 2, 160]] [1, 288, 2, 160] 0
Conv2D-44 [[1, 288, 2, 160]] [1, 288, 2, 160] 7,200
BatchNorm-30 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152
ConvBNLayer-30 [[1, 288, 2, 160]] [1, 288, 2, 160] 0
AdaptiveAvgPool2D-8 [[1, 288, 2, 160]] [1, 288, 1, 1] 0
Conv2D-45 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808
Conv2D-46 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024
SEModule-8 [[1, 288, 2, 160]] [1, 288, 2, 160] 0
Conv2D-47 [[1, 288, 2, 160]] [1, 48, 2, 160] 13,824
BatchNorm-31 [[1, 48, 2, 160]] [1, 48, 2, 160] 192
ConvBNLayer-31 [[1, 288, 2, 160]] [1, 48, 2, 160] 0
ResidualUnit-10 [[1, 48, 2, 160]] [1, 48, 2, 160] 0
Conv2D-48 [[1, 48, 2, 160]] [1, 288, 2, 160] 13,824
BatchNorm-32 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152
ConvBNLayer-32 [[1, 48, 2, 160]] [1, 288, 2, 160] 0
Conv2D-49 [[1, 288, 2, 160]] [1, 288, 2, 160] 7,200
BatchNorm-33 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152
ConvBNLayer-33 [[1, 288, 2, 160]] [1, 288, 2, 160] 0
AdaptiveAvgPool2D-9 [[1, 288, 2, 160]] [1, 288, 1, 1] 0
Conv2D-50 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808
Conv2D-51 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024
SEModule-9 [[1, 288, 2, 160]] [1, 288, 2, 160] 0
Conv2D-52 [[1, 288, 2, 160]] [1, 48, 2, 160] 13,824
BatchNorm-34 [[1, 48, 2, 160]] [1, 48, 2, 160] 192
ConvBNLayer-34 [[1, 288, 2, 160]] [1, 48, 2, 160] 0
ResidualUnit-11 [[1, 48, 2, 160]] [1, 48, 2, 160] 0
Conv2D-53 [[1, 48, 2, 160]] [1, 288, 2, 160] 13,824
BatchNorm-35 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152
ConvBNLayer-35 [[1, 48, 2, 160]] [1, 288, 2, 160] 0
MaxPool2D-1 [[1, 288, 2, 160]] [1, 288, 1, 80] 0
===============================================================================
Total params: 259,056
Trainable params: 246,736
Non-trainable params: 12,320
-------------------------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 44.38
Params size (MB): 0.99
Estimated Total Size (MB): 45.48
-------------------------------------------------------------------------------
{'total_params': 259056, 'trainable_params': 246736}
# feed an image through the backbone
backbone = MobileNetV3()
# convert the numpy array into a Tensor
input_data = paddle.to_tensor([padding_im])
# backbone output
feature = backbone(input_data)
# check the dimensions of the feature map
print("backbone output:", feature.shape)
backbone output: [1, 288, 1, 80]
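The shape [1, 288, 1, 80] can be read off from the structure above: 288 = make_divisible(0.5 × 576) channels; the height 32 is reduced to 1 by the stride-2 convolutions and the final max-pooling, while the width 320 is only halved by the first convolution and halved again by the pooling, so the 80 remaining columns become the 80 time steps of the sequence.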
- neck
The neck converts the visual feature map output by the backbone into a sequence of 1-D vectors, feeds them into an LSTM, and outputs sequence features (source code location):
class Im2Seq(nn.Layer):
def __init__(self, in_channels, **kwargs):
"""
图像特征转换为序列特征
:param in_channels: 输入通道数
"""
super().__init__()
self.out_channels = in_channels
def forward(self, x):
B, C, H, W = x.shape
assert H == 1
x = x.squeeze(axis=2)
x = x.transpose([0, 2, 1]) # (NWC)(batch, width, channels)
return x
class EncoderWithRNN(nn.Layer):
def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN, self).__init__()
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2)
def forward(self, x):
x, _ = self.lstm(x)
return x
class SequenceEncoder(nn.Layer):
def __init__(self, in_channels, hidden_size=48, **kwargs):
"""
序列编码
:param in_channels: 输入通道数
:param hidden_size: 隐藏层size
"""
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.encoder = EncoderWithRNN(
self.encoder_reshape.out_channels, hidden_size)
self.out_channels = self.encoder.out_channels
def forward(self, x):
x = self.encoder_reshape(x)
x = self.encoder(x)
return x
neck = SequenceEncoder(in_channels=288)
sequence = neck(feature)
print("sequence shape:", sequence.shape)
sequence shape: [1, 80, 96]
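The sequence length is 80 (one step per feature-map column) and each step has 96 channels, i.e. hidden_size 48 × 2 because the LSTM is bidirectional.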
- head
The prediction head consists of a fully connected layer and softmax, and computes the label probability distribution at each time step of the sequence features. This example only recognizes lower-case English letters and digits, i.e. 26 + 10 = 36 character classes (with the CTC blank, the head therefore outputs 37 channels) (source code location):
class CTCHead(nn.Layer):
def __init__(self,
in_channels,
out_channels,
**kwargs):
"""
CTC 预测层
:param in_channels: 输入通道数
:param out_channels: 输出通道数
"""
super(CTCHead, self).__init__()
self.fc = nn.Linear(
in_channels,
out_channels)
        # Think about it: what should out_channels be?
self.out_channels = out_channels
def forward(self, x):
predicts = self.fc(x)
result = predicts
if not self.training:
predicts = F.softmax(predicts, axis=2)
result = predicts
return result
With a randomly initialized network the output is meaningless. After softmax, the highest-probability prediction can be taken at each time step, where pred_id is the predicted label ID and pred_scores is the confidence of that prediction:
ctc_head = CTCHead(in_channels=96, out_channels=37)
predict = ctc_head(sequence)
print("predict shape:", predict.shape)
result = F.softmax(predict, axis=2)
pred_id = paddle.argmax(result, axis=2)
pred_scores = paddle.max(result, axis=2)
print("pred_id:", pred_id)
print("pred_scores:", pred_scores)
predict shape: [1, 80, 37]
pred_id: Tensor(shape=[1, 80], dtype=int64, place=CUDAPlace(0), stop_gradient=False,
[[23, 28, 23, 23, 23, 23, 23, 23, 23, 23, 23, 30, 30, 30, 31, 23, 23, 23, 23, 23, 23, 23, 31, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 5 ]])
pred_scores: Tensor(shape=[1, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
[[0.03683758, 0.03368053, 0.03604801, 0.03504696, 0.03696444, 0.03597261, 0.03925638, 0.03650934, 0.03873367, 0.03572492, 0.03543066, 0.03618268, 0.03805700, 0.03496549, 0.03329032, 0.03565763, 0.03846950, 0.03922413, 0.03970327, 0.03638541, 0.03572393, 0.03618102, 0.03565401, 0.03636984, 0.03691722, 0.03718850, 0.03623354, 0.03877943, 0.03731697, 0.03563465, 0.03447339, 0.03365586, 0.03312979, 0.03285240, 0.03273271, 0.03269565, 0.03269779, 0.03271412, 0.03273287, 0.03274929, 0.03276210, 0.03277146, 0.03277802, 0.03278249, 0.03278547, 0.03278742, 0.03278869, 0.03278949, 0.03279000, 0.03279032, 0.03279052, 0.03279064, 0.03279071, 0.03279077, 0.03279081, 0.03279087, 0.03279094, 0.03279106, 0.03279124, 0.03279152, 0.03279196, 0.03279264, 0.03279363, 0.03279509, 0.03279718, 0.03280006, 0.03280392, 0.03280888, 0.03281487, 0.03282148, 0.03282760, 0.03283087, 0.03282646, 0.03280647, 0.03275031, 0.03263619, 0.03242587, 0.03194289, 0.03122442, 0.02986610]])
- Post-processing
The recognition network ultimately returns the arg-max index at each time step, while the desired output is the corresponding text, so CRNN's post-processing is a decoding step whose main logic is as follows:
def decode(text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
character = "-0123456789abcdefghijklmnopqrstuvwxyz"
result_list = []
    # ignored tokens: index 0 is the CTC blank
ignored_tokens = [0]
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
            # merge identical characters between blanks
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
            # store the decoded character in char_list
char_list.append(character[int(text_index[batch_idx][
idx])])
            # record the confidence
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
        # output the result
result_list.append((text, np.mean(conf_list)))
return result_list
Taking the output of the randomly initialized head above as an example, decoding it gives:
pred_id = paddle.argmax(result, axis=2)
pred_scores = paddle.max(result, axis=2)
print(pred_id)
decode_out = decode(pred_id, pred_scores)
print("decode out:", decode_out)
Tensor(shape=[1, 80], dtype=int64, place=CUDAPlace(0), stop_gradient=False,
[[23, 28, 23, 23, 23, 23, 23, 23, 23, 23, 23, 30, 30, 30, 31, 23, 23, 23, 23, 23, 23, 23, 31, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 5 ]])
decode out: [('mrmmmmmmmmmtttummmmmmmummmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm4', 0.034180813)]
Quick test: if you feed in the indices predicted by a trained model, will the decoded result be correct?
# replace with the indices predicted by a trained model
right_pred_id = paddle.to_tensor([['xxxxxxxxxxxxx']])
tmp_scores = paddle.ones(shape=right_pred_id.shape)
out = decode(right_pred_id, tmp_scores)
print("out:",out)
out: [('pain', 1.0)]
The steps above complete the construction of the network and a simple forward pass.
An untrained network cannot make correct predictions, so a loss function and an optimization strategy are needed to actually run the whole network; the next part explains how training works in detail.
3. How Training Works
3.1 Prepare the training data
PaddleOCR supports two data formats:
- lmdb: for training on datasets stored in LMDB format (LMDBDataSet);
- general data: for training on datasets stored as plain text files (SimpleDataSet).
Only the general data format is covered here.
The default location for training data is ./train_data. Run the following command to extract the data:
!cd /home/aistudio/work/train_data/ && tar xf ic15_data.tar
After extraction the training images all sit in one folder, and a txt file (rec_gt_train.txt) records the image paths and labels. The content of the txt file looks like this:
" Image file name                 Image annotation "
train/word_1.png Genaxis Theatre
train/word_2.png [06]
...
Note: by default the image path and label in the txt file are separated by \t; using any other separator will cause errors during training. A small sketch for generating such a label file is shown after the directory layout below.
The dataset should have the following file structure:
|-train_data
|-ic15_data
|- rec_gt_train.txt
|- train
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
|- rec_gt_test.txt
|- test
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
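For a custom dataset the label file can be generated with a few lines of Python. A minimal sketch (the sample list below is a made-up example, not data shipped with this tutorial):
import os
samples = [("train/word_001.png", "Genaxis Theatre"),
           ("train/word_002.jpg", "[06]")]
os.makedirs("./train_data/ic15_data", exist_ok=True)
with open("./train_data/ic15_data/rec_gt_train.txt", "w", encoding="utf-8") as f:
    for img_path, label in samples:
        # image path and label must be separated by a tab character
        f.write("{}\t{}\n".format(img_path, label))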
Check that the data paths in the configuration file are correct, using rec_icdar15_train.yml as an example:
Train:
dataset:
name: SimpleDataSet
    # root directory of the training data
data_dir: ./train_data/ic15_data/
    # training label file
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [3, 32, 100] # [3,32,320]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 256
drop_last: True
num_workers: 8
use_shared_memory: False
Eval:
dataset:
name: SimpleDataSet
    # root directory of the evaluation data
data_dir: ./train_data/ic15_data
    # evaluation label file
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [3, 32, 100]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 4
use_shared_memory: False
3.2 Data preprocessing
Training data fed to the network must have consistent dimensions within a batch, and the features of different samples should be numerically comparable, so the data is resized to a uniform scale and normalized.
To make the model more robust, suppress overfitting and improve generalization, a certain amount of data augmentation is also needed.
- Resizing and normalization
This was already covered in Section 2; it is the last operation before an image is fed to the network. Call resize_norm_img to resize, pad and normalize the image.
- Data augmentation
PaddleOCR implements many augmentation methods, such as color inversion, random cropping, affine transformations and random noise. A simple random crop is shown here as an example; see rec_img_aug.py for more augmentation options.
def get_crop(image):
"""
random crop
"""
import random
h, w, _ = image.shape
top_min = 1
top_max = 8
top_crop = int(random.randint(top_min, top_max))
top_crop = min(top_crop, h - 1)
crop_img = image.copy()
ratio = random.randint(0, 1)
if ratio:
crop_img = crop_img[top_crop:h, :, :]
else:
crop_img = crop_img[0:h - top_crop, :, :]
return crop_img
# read the image
raw_img = cv2.imread("/home/aistudio/work/word_1.png")
plt.figure()
plt.subplot(2,1,1)
# show the original image
plt.imshow(raw_img)
# random crop
crop_img = get_crop(raw_img)
plt.subplot(2,1,2)
# show the augmented image
plt.imshow(crop_img)
plt.show()
3.3 Main training program
The entry point for model training is train.py, which shows all the modules needed for training: build dataloader, build post process, build model, build loss, build optim and build metric. Chaining these parts together starts the training:
- build dataloader
Training requires grouping the data into batches of a given size and yielding them one by one during training. This example uses the SimpleDataSet implemented in PaddleOCR, slightly modified from the original code; its main logic for returning a single sample is as follows:
def __getitem__(data_line, data_dir):
import os
mode = "train"
delimiter = '\t'
try:
substr = data_line.strip("\n").split(delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
        # preprocessing transforms, commented out here for now
# outs = transform(data, self.ops)
outs = data
except Exception as e:
print("When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
return outs
Assume the current input label line is train/word_1.png Genaxis Theatre and the training data path is /home/aistudio/work/train_data/ic15_data/. The parsed result is a dictionary containing three fields: img_path, label and image:
data_line = "train/word_1.png Genaxis Theatre"
data_dir = "/home/aistudio/work/train_data/ic15_data/"
item = __getitem__(data_line, data_dir)
print(item)
{'img_path': '/home/aistudio/work/train_data/ic15_data/train/word_1.png', 'label': 'Genaxis Theatre', 'image': b'\x89PNG\r\n\x1a\n...'}
Once the per-sample logic is implemented, wrapping it with paddle.io.DataLoader groups the samples into batches; see build_dataloader for the full implementation.
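A minimal sketch of such a wrapper, assuming the per-sample parsing function above is available and that, in real training, the transform ops turn each sample into fixed-shape tensors before batching:
import paddle
from paddle.io import Dataset, DataLoader

parse_line = __getitem__   # the standalone per-sample parsing function defined above

class SimpleRecDataset(Dataset):
    def __init__(self, label_file, data_dir):
        super().__init__()
        self.data_dir = data_dir
        with open(label_file, "r", encoding="utf-8") as f:
            self.lines = [line for line in f if line.strip()]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        return parse_line(self.lines[idx], self.data_dir)

# hypothetical paths matching the layout described in section 3.1
dataset = SimpleRecDataset("./train_data/ic15_data/rec_gt_train.txt",
                           "./train_data/ic15_data/")
loader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True)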
- build model
build model constructs the main network structure, exactly as described in section 2.3, so it is not repeated here; the code of each module can be found under modeling.
- build loss
The loss function of the CRNN model is CTC loss. PaddlePaddle provides the common loss functions, so it only needs to be called:
import paddle.nn as nn
class CTCLoss(nn.Layer):
def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
        # blank is the meaningless CTC separator symbol (index 0)
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
def forward(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
        # transpose the head output so the time dimension comes first
predicts = predicts.transpose((1, 0, 2)) #[80,1,37]
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64')
        # compute the CTC loss
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
loss = loss.mean()
return {'loss': loss}
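As a quick sanity check, the loss can be computed on the randomly initialized head output predict from section 2.3 with made-up labels (a sketch only, to show the expected shapes and dtypes):
fake_labels = paddle.randint(low=1, high=37, shape=[1, 25], dtype='int32')  # label ids, 0 is the blank
fake_label_lens = paddle.to_tensor([10], dtype='int64')                     # true text length
batch = [None, fake_labels, fake_label_lens]
ctc_loss = CTCLoss()
print(ctc_loss(predict, batch))   # {'loss': Tensor(...)}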
- build post process
The details were likewise covered in section 2.3; the logic is the same as before.
- build optim
The optimizer is Adam, again using the PaddlePaddle API paddle.optimizer.Adam.
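A minimal sketch of what build_optimizer assembles (the learning-rate schedule and regularization values below are illustrative, not necessarily those in rec_icdar15_train.yml):
lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.0005, T_max=500)
optimizer = paddle.optimizer.Adam(
    learning_rate=lr_scheduler,
    beta1=0.9,
    beta2=0.999,
    weight_decay=paddle.regularizer.L2Decay(1e-5),
    parameters=model.parameters())   # `model` is the network returned by build_model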
- build metric
The metric module computes the model's evaluation metrics. For text recognition in PaddleOCR, a prediction is counted as correct only if the entire string matches the label, so the main accuracy logic is as follows:
def metric(preds, labels):
    correct_num = 0
    all_num = 0
    for pred, target in zip(preds, labels):
        pred = pred.replace(" ", "")
        target = target.replace(" ", "")
        if pred == target:
            correct_num += 1
        all_num += 1
    return {
        'acc': correct_num / all_num,
    }
preds = ["aaa", "bbb", "ccc", "123", "456"]
labels = ["aaa", "bbb", "ddd", "123", "444"]
acc = metric(preds, labels)
print("acc:", acc)
# 五个预测结果中,完全正确的有3个,因此准确率应为0.6
acc: {'acc': 0.6}
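The training logs later in this section also report norm_edit_dis, which the toy metric above leaves out. Below is a rough sketch of how a normalized edit distance could be added, using the python-Levenshtein package installed earlier; PaddleOCR's RecMetric uses a similar 1 - normalized-distance formulation, but this snippet is an illustration rather than its exact code.
import Levenshtein

def norm_edit_distance(preds, labels):
    # average of 1 - edit_distance / max_len over all samples (higher is better)
    total = 0.0
    for pred, target in zip(preds, labels):
        pred, target = pred.replace(" ", ""), target.replace(" ", "")
        total += Levenshtein.distance(pred, target) / max(len(pred), len(target), 1)
    return 1 - total / max(len(preds), 1)

print(norm_edit_distance(["aaa", "bbb", "ccc", "123", "456"],
                         ["aaa", "bbb", "ddd", "123", "444"]))  # about 0.667 for this example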
Combining all of the parts above gives the complete training flow:
def main(config, device, logger, vdl_writer):
# init dist environment
if config['Global']['distributed']:
dist.init_parallel_env()
global_config = config['Global']
# build dataloader
train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0:
logger.error(
"No Images in train dataset, please ensure\n" +
"\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+
"\t2. The annotation file and path in the configuration file are provided normally."
)
return
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
else:
valid_dataloader = None
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build loss
loss_class = build_loss(config['Loss'])
# build optim
optimizer, lr_scheduler = build_optimizer(
config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False)
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
else:
scaler = None
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
4. Full Training Task
4.1 Start Training
Like the detection task, the PaddleOCR recognition task passes its parameters through a configuration file.
To run a full training job, first download the whole project and install its dependencies:
# Clone the PaddleOCR code
# !git clone https://gitee.com/paddlepaddle/PaddleOCR
# Change the default working directory to /home/aistudio/PaddleOCR
import os
os.chdir("/home/aistudio/PaddleOCR")
# Install PaddleOCR's third-party dependencies
!pip install -r requirements.txt
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: shapely in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.8.0)
Collecting scikit-image==0.17.2
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d7/ee/753ea56fda5bc2a5516a1becb631bf5ada593a2dd44f21971a13a762d4db/scikit_image-0.17.2-cp37-cp37m-manylinux1_x86_64.whl (12.5 MB)
|████████████████████████████████| 12.5 MB 13.3 MB/s
[?25hRequirement already satisfied: imgaug==0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (0.4.0)
Requirement already satisfied: pyclipper in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.3.0.post2)
Requirement already satisfied: lmdb in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (1.2.1)
Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.36.1)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)
Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)
Requirement already satisfied: python-Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.12.2)
Requirement already satisfied: opencv-contrib-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (4.4.0.46)
Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (4.7.1)
Requirement already satisfied: premailer in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (3.10.0)
Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 13)) (3.0.5)
Requirement already satisfied: networkx>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4)
Requirement already satisfied: imageio>=2.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.6.1)
Requirement already satisfied: scipy>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.6.3)
Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (7.1.2)
Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.2.0)
Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2021.11.2)
Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.2.3)
Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)
Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)
Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)
Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)
Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)
Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)
Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)
Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-Levenshtein->-r requirements.txt (line 9)) (56.2.0)
Requirement already satisfied: cssselect in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (1.1.0)
Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (4.0.0)
Requirement already satisfied: cssutils in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (2.3.0)
Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.0.1)
Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.4.1)
Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)
Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)
Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)
Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)
Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)
Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.8.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (1.1.0)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (0.10.0)
Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (4.4.2)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)
Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)
Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)
Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)
Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)
Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)
Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)
Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)
Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)
Installing collected packages: scikit-image
Attempting uninstall: scikit-image
Found existing installation: scikit-image 0.19.1
Uninstalling scikit-image-0.19.1:
Successfully uninstalled scikit-image-0.19.1
Successfully installed scikit-image-0.17.2
Create a symlink so that the training data sits under the PaddleOCR project:
!ln -s /home/aistudio/work/train_data/ /home/aistudio/PaddleOCR/
Download the pretrained model:
To speed up convergence, it is recommended to download a trained model and finetune it on the icdar2015 data.
# note: the working directory was already set by os.chdir above; a `!cd` in a notebook cell does not persist anyway
!cd PaddleOCR/
# Download the MobileNetV3 pretrained model
!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
# Extract the model parameters
!tar -xf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar
--2021-12-22 15:39:39-- https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 51200000 (49M) [application/x-tar]
Saving to: ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’
rec_mv3_none_bilstm 100%[===================>] 48.83M 15.5MB/s in 3.6s
2021-12-22 15:39:42 (13.7 MB/s) - ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’ saved [51200000/51200000]
Starting training is simple: just point the command at the configuration file. In addition, parameter values in the configuration file can be overridden on the command line with -o. The training command is shown below, where:
- Global.pretrained_model: path of the pretrained model to load
- Global.character_dict_path: dictionary path (here only the 26 lowercase letters plus digits are supported)
- Global.eval_batch_step: evaluation frequency
- Global.epoch_num: total number of training epochs
!python3 tools/train.py -c configs/rec/rec_icdar15_train.yml \
-o Global.pretrained_model=rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy \
Global.character_dict_path=ppocr/utils/ic15_dict.txt \
Global.eval_batch_step=[0,200] \
Global.epoch_num=40
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)
[2021/12/23 19:28:52] root INFO: Architecture :
[2021/12/23 19:28:52] root INFO: Backbone :
[2021/12/23 19:28:52] root INFO: model_name : large
[2021/12/23 19:28:52] root INFO: name : MobileNetV3
[2021/12/23 19:28:52] root INFO: scale : 0.5
[2021/12/23 19:28:52] root INFO: Head :
[2021/12/23 19:28:52] root INFO: fc_decay : 0
[2021/12/23 19:28:52] root INFO: name : CTCHead
[2021/12/23 19:28:52] root INFO: Neck :
[2021/12/23 19:28:52] root INFO: encoder_type : rnn
[2021/12/23 19:28:52] root INFO: hidden_size : 96
[2021/12/23 19:28:52] root INFO: name : SequenceEncoder
[2021/12/23 19:28:52] root INFO: Transform : None
[2021/12/23 19:28:52] root INFO: algorithm : CRNN
[2021/12/23 19:28:52] root INFO: model_type : rec
[2021/12/23 19:28:52] root INFO: Eval :
[2021/12/23 19:28:52] root INFO: dataset :
[2021/12/23 19:28:52] root INFO: data_dir : ./train_data/ic15_data
[2021/12/23 19:28:52] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
[2021/12/23 19:28:52] root INFO: name : SimpleDataSet
[2021/12/23 19:28:52] root INFO: transforms :
[2021/12/23 19:28:52] root INFO: DecodeImage :
[2021/12/23 19:28:52] root INFO: channel_first : False
[2021/12/23 19:28:52] root INFO: img_mode : BGR
[2021/12/23 19:28:52] root INFO: CTCLabelEncode : None
[2021/12/23 19:28:52] root INFO: RecResizeImg :
[2021/12/23 19:28:52] root INFO: image_shape : [3, 32, 100]
[2021/12/23 19:28:52] root INFO: KeepKeys :
[2021/12/23 19:28:52] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/23 19:28:52] root INFO: loader :
[2021/12/23 19:28:52] root INFO: batch_size_per_card : 256
[2021/12/23 19:28:52] root INFO: drop_last : False
[2021/12/23 19:28:52] root INFO: num_workers : 4
[2021/12/23 19:28:52] root INFO: shuffle : False
[2021/12/23 19:28:52] root INFO: use_shared_memory : False
[2021/12/23 19:28:52] root INFO: Global :
[2021/12/23 19:28:52] root INFO: cal_metric_during_train : True
[2021/12/23 19:28:52] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt
[2021/12/23 19:28:52] root INFO: character_type : EN
[2021/12/23 19:28:52] root INFO: checkpoints : None
[2021/12/23 19:28:52] root INFO: debug : False
[2021/12/23 19:28:52] root INFO: distributed : False
[2021/12/23 19:28:52] root INFO: epoch_num : 40
[2021/12/23 19:28:52] root INFO: eval_batch_step : [0, 200]
[2021/12/23 19:28:52] root INFO: infer_img : doc/imgs_words_en/word_19.png
[2021/12/23 19:28:52] root INFO: infer_mode : False
[2021/12/23 19:28:52] root INFO: log_smooth_window : 20
[2021/12/23 19:28:52] root INFO: max_text_length : 25
[2021/12/23 19:28:52] root INFO: pretrained_model : rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy
[2021/12/23 19:28:52] root INFO: print_batch_step : 10
[2021/12/23 19:28:52] root INFO: save_epoch_step : 3
[2021/12/23 19:28:52] root INFO: save_inference_dir : ./
[2021/12/23 19:28:52] root INFO: save_model_dir : ./output/rec/ic15/
[2021/12/23 19:28:52] root INFO: save_res_path : ./output/rec/predicts_ic15.txt
[2021/12/23 19:28:52] root INFO: use_gpu : True
[2021/12/23 19:28:52] root INFO: use_space_char : False
[2021/12/23 19:28:52] root INFO: use_visualdl : False
[2021/12/23 19:28:52] root INFO: Loss :
[2021/12/23 19:28:52] root INFO: name : CTCLoss
[2021/12/23 19:28:52] root INFO: Metric :
[2021/12/23 19:28:52] root INFO: main_indicator : acc
[2021/12/23 19:28:52] root INFO: name : RecMetric
[2021/12/23 19:28:52] root INFO: Optimizer :
[2021/12/23 19:28:52] root INFO: beta1 : 0.9
[2021/12/23 19:28:52] root INFO: beta2 : 0.999
[2021/12/23 19:28:52] root INFO: lr :
[2021/12/23 19:28:52] root INFO: learning_rate : 0.0005
[2021/12/23 19:28:52] root INFO: name : Adam
[2021/12/23 19:28:52] root INFO: regularizer :
[2021/12/23 19:28:52] root INFO: factor : 0
[2021/12/23 19:28:52] root INFO: name : L2
[2021/12/23 19:28:52] root INFO: PostProcess :
[2021/12/23 19:28:52] root INFO: name : CTCLabelDecode
[2021/12/23 19:28:52] root INFO: Train :
[2021/12/23 19:28:52] root INFO: dataset :
[2021/12/23 19:28:52] root INFO: data_dir : ./train_data/ic15_data/
[2021/12/23 19:28:52] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:28:52] root INFO: name : SimpleDataSet
[2021/12/23 19:28:52] root INFO: transforms :
[2021/12/23 19:28:52] root INFO: DecodeImage :
[2021/12/23 19:28:52] root INFO: channel_first : False
[2021/12/23 19:28:52] root INFO: img_mode : BGR
[2021/12/23 19:28:52] root INFO: CTCLabelEncode : None
[2021/12/23 19:28:52] root INFO: RecResizeImg :
[2021/12/23 19:28:52] root INFO: image_shape : [3, 32, 100]
[2021/12/23 19:28:52] root INFO: KeepKeys :
[2021/12/23 19:28:52] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/23 19:28:52] root INFO: loader :
[2021/12/23 19:28:52] root INFO: batch_size_per_card : 256
[2021/12/23 19:28:52] root INFO: drop_last : True
[2021/12/23 19:28:52] root INFO: num_workers : 8
[2021/12/23 19:28:52] root INFO: shuffle : True
[2021/12/23 19:28:52] root INFO: use_shared_memory : False
[2021/12/23 19:28:52] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)
[2021/12/23 19:28:52] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:28:52] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt']
W1223 19:28:52.737390 2821 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W1223 19:28:52.742431 2821 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[2021/12/23 19:28:55] root INFO: loaded pretrained_model successful from rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy.pdparams
[2021/12/23 19:28:55] root INFO: train dataloader has 17 iters
[2021/12/23 19:28:55] root INFO: valid dataloader has 9 iters
[2021/12/23 19:28:55] root INFO: During the training process, after the 0th iteration, an evaluation is run every 200 iterations
[2021/12/23 19:28:55] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:29:00] root INFO: epoch: [1/40], iter: 10, lr: 0.000500, loss: 8.913865, acc: 0.195312, norm_edit_dis: 0.686087, reader_cost: 0.24545 s, batch_cost: 0.41798 s, samples: 2816, ips: 673.71029
[2021/12/23 19:29:01] root INFO: epoch: [1/40], iter: 16, lr: 0.000500, loss: 7.154922, acc: 0.222656, norm_edit_dis: 0.684251, reader_cost: 0.00006 s, batch_cost: 0.06422 s, samples: 1536, ips: 2391.80670
[2021/12/23 19:29:01] root INFO: save model in ./output/rec/ic15/latest
[2021/12/23 19:29:01] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 19:29:05] root INFO: epoch: [2/40], iter: 20, lr: 0.000500, loss: 6.198568, acc: 0.246094, norm_edit_dis: 0.688872, reader_cost: 0.20438 s, batch_cost: 0.34878 s, samples: 1024, ips: 293.59625
[2021/12/23 19:29:06] root INFO: epoch: [2/40], iter: 30, lr: 0.000500, loss: 4.117401, acc: 0.402344, norm_edit_dis: 0.739680, reader_cost: 0.00016 s, batch_cost: 0.11199 s, samples: 2560, ips: 2285.86780
^C
main proc 2870 exit, kill process group 2821
main proc 2869 exit, kill process group 2821
main proc 2868 exit, kill process group 2821
main proc 2867 exit, kill process group 2821
main proc 2866 exit, kill process group 2821
main proc 2865 exit, kill process group 2821
main proc 2864 exit, kill process group 2821
main proc 2821 exit, kill process group 2821
According to the save_model_dir field set in the configuration file, the following files are saved:
output/rec/ic15
├── best_accuracy.pdopt
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml
├── iter_epoch_3.pdopt
├── iter_epoch_3.pdparams
├── iter_epoch_3.states
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log
Here best_accuracy.* is the best model on the evaluation set; iter_epoch_x.* are the models saved every save_epoch_step epochs; and latest.* is the model from the last epoch.
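If you want to inspect or reuse these files outside the training scripts, a minimal sketch (assuming the save_model_dir shown above) looks like this:
import paddle

# .pdparams holds the model weights, .pdopt the optimizer state; both load as plain dicts
state_dict = paddle.load("output/rec/ic15/best_accuracy.pdparams")
print(len(state_dict), "parameter tensors")
# to resume or finetune, build the same architecture and call model.set_state_dict(state_dict)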
Summary:
To train on your own data, you need to modify:
- the training and evaluation data paths (required)
- the dictionary path (required)
- the pretrained model (optional)
- the learning rate, image shape, and network structure (optional)
4.2 Model Evaluation
The evaluation dataset can be changed in configs/rec/rec_icdar15_train.yml by modifying label_file_list under Eval.
Here we use the icdar2015 evaluation set by default and load the model weights just trained:
!python tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy \
Global.character_dict_path=ppocr/utils/ic15_dict.txt
[2021/12/23 14:27:51] root INFO: Architecture :
[2021/12/23 14:27:51] root INFO: Backbone :
[2021/12/23 14:27:51] root INFO: model_name : large
[2021/12/23 14:27:51] root INFO: name : MobileNetV3
[2021/12/23 14:27:51] root INFO: scale : 0.5
[2021/12/23 14:27:51] root INFO: Head :
[2021/12/23 14:27:51] root INFO: fc_decay : 0
[2021/12/23 14:27:51] root INFO: name : CTCHead
[2021/12/23 14:27:51] root INFO: Neck :
[2021/12/23 14:27:51] root INFO: encoder_type : rnn
[2021/12/23 14:27:51] root INFO: hidden_size : 96
[2021/12/23 14:27:51] root INFO: name : SequenceEncoder
[2021/12/23 14:27:51] root INFO: Transform : None
[2021/12/23 14:27:51] root INFO: algorithm : CRNN
[2021/12/23 14:27:51] root INFO: model_type : rec
[2021/12/23 14:27:51] root INFO: Eval :
[2021/12/23 14:27:51] root INFO: dataset :
[2021/12/23 14:27:51] root INFO: data_dir : ./train_data/ic15_data
[2021/12/23 14:27:51] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
[2021/12/23 14:27:51] root INFO: name : SimpleDataSet
[2021/12/23 14:27:51] root INFO: transforms :
[2021/12/23 14:27:51] root INFO: DecodeImage :
[2021/12/23 14:27:51] root INFO: channel_first : False
[2021/12/23 14:27:51] root INFO: img_mode : BGR
[2021/12/23 14:27:51] root INFO: CTCLabelEncode : None
[2021/12/23 14:27:51] root INFO: RecResizeImg :
[2021/12/23 14:27:51] root INFO: image_shape : [3, 32, 100]
[2021/12/23 14:27:51] root INFO: KeepKeys :
[2021/12/23 14:27:51] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/23 14:27:51] root INFO: loader :
[2021/12/23 14:27:51] root INFO: batch_size_per_card : 256
[2021/12/23 14:27:51] root INFO: drop_last : False
[2021/12/23 14:27:51] root INFO: num_workers : 4
[2021/12/23 14:27:51] root INFO: shuffle : False
[2021/12/23 14:27:51] root INFO: use_shared_memory : False
[2021/12/23 14:27:51] root INFO: Global :
[2021/12/23 14:27:51] root INFO: cal_metric_during_train : True
[2021/12/23 14:27:51] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt
[2021/12/23 14:27:51] root INFO: character_type : EN
[2021/12/23 14:27:51] root INFO: checkpoints : output/rec/ic15/best_accuracy
[2021/12/23 14:27:51] root INFO: debug : False
[2021/12/23 14:27:51] root INFO: distributed : False
[2021/12/23 14:27:51] root INFO: epoch_num : 72
[2021/12/23 14:27:51] root INFO: eval_batch_step : [0, 2000]
[2021/12/23 14:27:51] root INFO: infer_img : doc/imgs_words_en/word_10.png
[2021/12/23 14:27:51] root INFO: infer_mode : False
[2021/12/23 14:27:51] root INFO: log_smooth_window : 20
[2021/12/23 14:27:51] root INFO: max_text_length : 25
[2021/12/23 14:27:51] root INFO: pretrained_model : None
[2021/12/23 14:27:51] root INFO: print_batch_step : 10
[2021/12/23 14:27:51] root INFO: save_epoch_step : 3
[2021/12/23 14:27:51] root INFO: save_inference_dir : ./
[2021/12/23 14:27:51] root INFO: save_model_dir : ./output/rec/ic15/
[2021/12/23 14:27:51] root INFO: save_res_path : ./output/rec/predicts_ic15.txt
[2021/12/23 14:27:51] root INFO: use_gpu : True
[2021/12/23 14:27:51] root INFO: use_space_char : False
[2021/12/23 14:27:51] root INFO: use_visualdl : False
[2021/12/23 14:27:51] root INFO: Loss :
[2021/12/23 14:27:51] root INFO: name : CTCLoss
[2021/12/23 14:27:51] root INFO: Metric :
[2021/12/23 14:27:51] root INFO: main_indicator : acc
[2021/12/23 14:27:51] root INFO: name : RecMetric
[2021/12/23 14:27:51] root INFO: Optimizer :
[2021/12/23 14:27:51] root INFO: beta1 : 0.9
[2021/12/23 14:27:51] root INFO: beta2 : 0.999
[2021/12/23 14:27:51] root INFO: lr :
[2021/12/23 14:27:51] root INFO: learning_rate : 0.0005
[2021/12/23 14:27:51] root INFO: name : Adam
[2021/12/23 14:27:51] root INFO: regularizer :
[2021/12/23 14:27:51] root INFO: factor : 0
[2021/12/23 14:27:51] root INFO: name : L2
[2021/12/23 14:27:51] root INFO: PostProcess :
[2021/12/23 14:27:51] root INFO: name : CTCLabelDecode
[2021/12/23 14:27:51] root INFO: Train :
[2021/12/23 14:27:51] root INFO: dataset :
[2021/12/23 14:27:51] root INFO: data_dir : ./train_data/ic15_data/
[2021/12/23 14:27:51] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 14:27:51] root INFO: name : SimpleDataSet
[2021/12/23 14:27:51] root INFO: transforms :
[2021/12/23 14:27:51] root INFO: DecodeImage :
[2021/12/23 14:27:51] root INFO: channel_first : False
[2021/12/23 14:27:51] root INFO: img_mode : BGR
[2021/12/23 14:27:51] root INFO: CTCLabelEncode : None
[2021/12/23 14:27:51] root INFO: RecResizeImg :
[2021/12/23 14:27:51] root INFO: image_shape : [3, 32, 100]
[2021/12/23 14:27:51] root INFO: KeepKeys :
[2021/12/23 14:27:51] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/23 14:27:51] root INFO: loader :
[2021/12/23 14:27:51] root INFO: batch_size_per_card : 256
[2021/12/23 14:27:51] root INFO: drop_last : True
[2021/12/23 14:27:51] root INFO: num_workers : 8
[2021/12/23 14:27:51] root INFO: shuffle : True
[2021/12/23 14:27:51] root INFO: use_shared_memory : False
[2021/12/23 14:27:51] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)
[2021/12/23 14:27:51] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt']
W1223 14:27:51.861889 5192 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1223 14:27:51.865501 5192 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[2021/12/23 14:27:56] root INFO: resume from output/rec/ic15/best_accuracy
[2021/12/23 14:27:56] root INFO: metric in ckpt ***************
[2021/12/23 14:27:56] root INFO: acc:0.48531535869041886
[2021/12/23 14:27:56] root INFO: norm_edit_dis:0.7895228681338454
[2021/12/23 14:27:56] root INFO: fps:3266.1877400927865
[2021/12/23 14:27:56] root INFO: best_epoch:24
[2021/12/23 14:27:56] root INFO: start_epoch:25
eval model:: 100%|████████████████████████████████| 9/9 [00:02<00:00, 3.32it/s]
[2021/12/23 14:27:59] root INFO: metric eval ***************
[2021/12/23 14:27:59] root INFO: acc:0.48531535869041886
[2021/12/23 14:27:59] root INFO: norm_edit_dis:0.7895228681338454
[2021/12/23 14:27:59] root INFO: fps:4491.015930181665
After evaluation, you can see the accuracy of the trained model on the validation set.
PaddleOCR supports alternating training and evaluation. The evaluation frequency can be set via eval_batch_step in configs/rec/rec_icdar15_train.yml; by default the model is evaluated every 2000 iterations, and during evaluation the model with the best acc is saved as output/rec/ic15/best_accuracy.
If the validation set is large, evaluation can be time-consuming, so consider reducing the evaluation frequency or evaluating only after training finishes.
4.3 Prediction
With a model trained by PaddleOCR, quick predictions can be made with the following script.
Predict an image:
The image to predict is given by infer_img by default, and the trained parameter file is loaded via -o Global.checkpoints:
!python tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy Global.character_dict_path=ppocr/utils/ic15_dict.txt
[2021/12/23 14:29:19] root INFO: Architecture :
[2021/12/23 14:29:19] root INFO: Backbone :
[2021/12/23 14:29:19] root INFO: model_name : large
[2021/12/23 14:29:19] root INFO: name : MobileNetV3
[2021/12/23 14:29:19] root INFO: scale : 0.5
[2021/12/23 14:29:19] root INFO: Head :
[2021/12/23 14:29:19] root INFO: fc_decay : 0
[2021/12/23 14:29:19] root INFO: name : CTCHead
[2021/12/23 14:29:19] root INFO: Neck :
[2021/12/23 14:29:19] root INFO: encoder_type : rnn
[2021/12/23 14:29:19] root INFO: hidden_size : 96
[2021/12/23 14:29:19] root INFO: name : SequenceEncoder
[2021/12/23 14:29:19] root INFO: Transform : None
[2021/12/23 14:29:19] root INFO: algorithm : CRNN
[2021/12/23 14:29:19] root INFO: model_type : rec
[2021/12/23 14:29:19] root INFO: Eval :
[2021/12/23 14:29:19] root INFO: dataset :
[2021/12/23 14:29:19] root INFO: data_dir : ./train_data/ic15_data
[2021/12/23 14:29:19] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
[2021/12/23 14:29:19] root INFO: name : SimpleDataSet
[2021/12/23 14:29:19] root INFO: transforms :
[2021/12/23 14:29:19] root INFO: DecodeImage :
[2021/12/23 14:29:19] root INFO: channel_first : False
[2021/12/23 14:29:19] root INFO: img_mode : BGR
[2021/12/23 14:29:19] root INFO: CTCLabelEncode : None
[2021/12/23 14:29:19] root INFO: RecResizeImg :
[2021/12/23 14:29:19] root INFO: image_shape : [3, 32, 100]
[2021/12/23 14:29:19] root INFO: KeepKeys :
[2021/12/23 14:29:19] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/23 14:29:19] root INFO: loader :
[2021/12/23 14:29:19] root INFO: batch_size_per_card : 256
[2021/12/23 14:29:19] root INFO: drop_last : False
[2021/12/23 14:29:19] root INFO: num_workers : 4
[2021/12/23 14:29:19] root INFO: shuffle : False
[2021/12/23 14:29:19] root INFO: use_shared_memory : False
[2021/12/23 14:29:19] root INFO: Global :
[2021/12/23 14:29:19] root INFO: cal_metric_during_train : True
[2021/12/23 14:29:19] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt
[2021/12/23 14:29:19] root INFO: character_type : EN
[2021/12/23 14:29:19] root INFO: checkpoints : output/rec/ic15/best_accuracy
[2021/12/23 14:29:19] root INFO: debug : False
[2021/12/23 14:29:19] root INFO: distributed : False
[2021/12/23 14:29:19] root INFO: epoch_num : 72
[2021/12/23 14:29:19] root INFO: eval_batch_step : [0, 2000]
[2021/12/23 14:29:19] root INFO: infer_img : doc/imgs_words_en/word_19.png
[2021/12/23 14:29:19] root INFO: infer_mode : False
[2021/12/23 14:29:19] root INFO: log_smooth_window : 20
[2021/12/23 14:29:19] root INFO: max_text_length : 25
[2021/12/23 14:29:19] root INFO: pretrained_model : None
[2021/12/23 14:29:19] root INFO: print_batch_step : 10
[2021/12/23 14:29:19] root INFO: save_epoch_step : 3
[2021/12/23 14:29:19] root INFO: save_inference_dir : ./
[2021/12/23 14:29:19] root INFO: save_model_dir : ./output/rec/ic15/
[2021/12/23 14:29:19] root INFO: save_res_path : ./output/rec/predicts_ic15.txt
[2021/12/23 14:29:19] root INFO: use_gpu : True
[2021/12/23 14:29:19] root INFO: use_space_char : False
[2021/12/23 14:29:19] root INFO: use_visualdl : False
[2021/12/23 14:29:19] root INFO: Loss :
[2021/12/23 14:29:19] root INFO: name : CTCLoss
[2021/12/23 14:29:19] root INFO: Metric :
[2021/12/23 14:29:19] root INFO: main_indicator : acc
[2021/12/23 14:29:19] root INFO: name : RecMetric
[2021/12/23 14:29:19] root INFO: Optimizer :
[2021/12/23 14:29:19] root INFO: beta1 : 0.9
[2021/12/23 14:29:19] root INFO: beta2 : 0.999
[2021/12/23 14:29:19] root INFO: lr :
[2021/12/23 14:29:19] root INFO: learning_rate : 0.0005
[2021/12/23 14:29:19] root INFO: name : Adam
[2021/12/23 14:29:19] root INFO: regularizer :
[2021/12/23 14:29:19] root INFO: factor : 0
[2021/12/23 14:29:19] root INFO: name : L2
[2021/12/23 14:29:19] root INFO: PostProcess :
[2021/12/23 14:29:19] root INFO: name : CTCLabelDecode
[2021/12/23 14:29:19] root INFO: Train :
[2021/12/23 14:29:19] root INFO: dataset :
[2021/12/23 14:29:19] root INFO: data_dir : ./train_data/ic15_data/
[2021/12/23 14:29:19] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
[2021/12/23 14:29:19] root INFO: name : SimpleDataSet
[2021/12/23 14:29:19] root INFO: transforms :
[2021/12/23 14:29:19] root INFO: DecodeImage :
[2021/12/23 14:29:19] root INFO: channel_first : False
[2021/12/23 14:29:19] root INFO: img_mode : BGR
[2021/12/23 14:29:19] root INFO: CTCLabelEncode : None
[2021/12/23 14:29:19] root INFO: RecResizeImg :
[2021/12/23 14:29:19] root INFO: image_shape : [3, 32, 100]
[2021/12/23 14:29:19] root INFO: KeepKeys :
[2021/12/23 14:29:19] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/23 14:29:19] root INFO: loader :
[2021/12/23 14:29:19] root INFO: batch_size_per_card : 256
[2021/12/23 14:29:19] root INFO: drop_last : True
[2021/12/23 14:29:19] root INFO: num_workers : 8
[2021/12/23 14:29:19] root INFO: shuffle : True
[2021/12/23 14:29:19] root INFO: use_shared_memory : False
[2021/12/23 14:29:19] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)
W1223 14:29:19.803710 5290 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1223 14:29:19.807695 5290 device_context.cc:422] device: 0, cuDNN Version: 7.6.
[2021/12/23 14:29:25] root INFO: resume from output/rec/ic15/best_accuracy
[2021/12/23 14:29:25] root INFO: infer_img: doc/imgs_words_en/word_19.png
pred idx: Tensor(shape=[1, 25], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
[[29, 0 , 0 , 0 , 22, 0 , 0 , 0 , 25, 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 33]])
[2021/12/23 14:29:25] root INFO: result: slow 0.8795223
[2021/12/23 14:29:25] root INFO: success!
The prediction result for the input image is:
infer_img: doc/imgs_words_en/word_19.png
result: slow 0.8795223
Assignments
[Question 1]
Visualize the results of the data augmentations implemented in PaddleOCR (noise and jitter) and describe their effects in words.
Optional test images:
https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_1.jpg
https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_2.jpg
[Question 2]
Change the backbone in the configs/rec/rec_icdar15_train.yml configuration to PaddleOCR's ResNet34_vd. With an input image shape of (3, 32, 100), what is the final output feature size of the Head layer?
[Question 3]
Download the 100k-sample Chinese dataset rec_data_lesson_demo and modify the configs/rec/rec_icdar15_train.yml configuration file to train a recognition model; provide the training log.
You may load this pretrained model: https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
Summary
At this point, a complete CRNN-based text recognition task has been walked through end to end; for more features and code, see PaddleOCR.
If you have any questions about this project, feel free to leave them in the comments.