1 项目介绍

1.1 项目功能

(1)项目功能:英文手写识别,如输入数据为手写英文作文扫描图片,技术:OCR技术
(2)应用场景:

  • 高考等应试教育英语作文电子阅卷
  • 英文手写电子笔记的上传

    1.2 评估指标

    (1)模型评估指标:动态规划实现的字符串相似度算法,公式如下
    截屏2020-12-22 上午11.25.53.png

    2 数据集介绍

    2.1 数据特征

    数据集中存在的问题和难点
    (1)数据集数量不够大,
    (2)扫描得到的图像倾斜,文字区域无法定位,或文字区域无法精准定位
    (3)图片中有很多噪声信息,如下划线
    (4)图片中手写英文存在很多连笔,涂改等。

截屏2020-12-22 上午11.29.28.png

3 数据的预处理

3.1 数据增强

数据集的预处理工作,为提升模型的性能,做了数据预处理工作。
(1)图像旋转
(2)图像缩放
(3)对图像添加噪声
(4)对图像进行模糊
(5)将图像往x、y方向上按指定的数量移动图像像素

截屏2020-12-22 上午11.32.24.png

3.2 倾斜矫正

扫描版或拍着图片会存在图像倾斜的情况,将大大降低识别效果,因此需要对图像进行倾斜矫正预处理。
截屏2020-12-22 上午11.34.04.png
原理:
(1)首先对图像进行边缘轮廓检测
(2)对边缘轮廓检测后的图片进行霍夫曼倾斜矫正
霍夫曼倾斜矫正原理:通过识别图像中的直线,检测直线倾斜角度和直线的位置信息对图像进行旋转,实测效果佳,且边缘轮廓检测对霍夫曼倾斜矫正起到很好的辅助作用。

3.3 去横线

图像中的横线对文字识别有一定的影响,因此需要在识别前对图像进行横线去除工作,去除横线方法调用Ieptonical库
原理:
(1)首先旋转图片进行倾斜矫正,使得横线变水平,然后提取出水平横线调用函数班背景去掉,只留下横线
(2)接着将横线进行阈值处理,高于阈值的横线加黑,低于阈值的变白,将处理图片上的黑色横线翻转为白色
(3)步骤2原图的横线被去掉,但原图人物身体的部分也被擦除
(4)此时调用相关函数使横线图片与人物擦除的图片想结合,补出擦除的部分,得到较好的去横线的效果。

3.4 文本区域定位

英文行全页面自动定位算法,文本区域定位,在输入神经网络模型前需要做文本区域定位,基于MSER算法进行改进。
算法原理:MSER算法产生的局部文字区域杂乱,对MSER产生的边框又进行了下面的四步筛选,大大提升了问题区域定位的效果
(1)首先根据矩形的大小,将过大或过小的矩形筛除掉
(2)将大矩形和小矩形如果交叠部分大于设定的阈值,将小矩形筛除掉
(3)此步特殊之处在于并不筛除掉矩阵,而是按照规则取min_left_top_x min_left_y max_right_bottom_x max_right_bottom_y 将一个类的矩阵合并成一个大的矩形
(4)按照矩形边框的height > min_height weight > min_weight筛选出最后的边框。

4 网络结构

(1)神经网络结构:四层卷积层+四层池化层
(2)神经网络使用的是:双向LSTM

截屏2020-12-22 下午9.00.48.png

(3)结构分析
该神经网络使用的是双向递归神经网络tf.nn.bidirectional_dynamic_rnn()
双向的RNN,当cell使用LSTM时,便是双向LSTMD。单向的RNN只考虑上文的信息对下文信息的影响,双向RNN即考虑当前信息不仅受到上文的影响,同时也考虑下文的影响。
前向RNN和dynamic_rnn完全一致,后向RNN输入的序列经过了反转。
(4)优化算法
本神经网络使用的参数优化算法:AdamOptimizer。除了该算法还有Momentum优化算法

  • Momentum优化算法计算梯度的指数加权平均,加快迭代速度
  • Adam算法集成了momentum动量梯度下降法和RMSprop梯度下降法的优点

(5)损失函数
CTC损失函数(connectionist temporal classification)
截屏2020-12-22 下午9.09.36.png
CTC在神经网络中计算一种损失值,主要用于可以对没有对齐的数据进行自动补齐,即主要是用在没有事先对齐的序列化数据训练上,应用领域如:语音识别、OCR识别

(6)池化
池化过程使用最大池化max_pool.
原因:虽然最大池化和平均池化都对数据进行了下采样,但是最大池化做特征选择,选出了分类识别度更好的特征

  • 最大池化:可以降低卷积层参数误差造成估计均值的偏移,更多的保留纹理信息
  • 最大池化提供了非线性,这是最大池化效果更好的原因

(7)使用Dropout
使用Dropout简化训练的网络结构,控制过拟合出现的风险,并通过调整,得到了一个比较合适的dropout参数
(8)激活函数
使用relu函数。可以降低网络参数训练过程中梯度消失或者梯度爆炸的风险
(9)断点续训
保证函数在训练中断后继续进行训练。

5 OCR实现

OCRGitHub源码下载

ocr_generated.py

  1. import os
  2. import glob
  3. import random
  4. import numpy as np
  5. from PIL import Image
  6. from PIL import ImageFilter
  7. #记录一个问题: tf.placeholder 报错InvalidArgumentError: You must feed a value for placeholder tensor 'inputs/x_input'
  8. #chr函数: 将数字转化成字符
  9. #ord函数: 将字符转化成数字
  10. #characterNo字典:a-z, A-Z, 0-10, " .,?\'-:;!/\"<>&(+" 为key分别对应值是0-25,26-51,52-61,62...
  11. #characters列表: 存储的是cahracterNo字典的key
  12. #建立characterNo字典的意思是: 为了将之后手写体对应的txt文件中的句子转化成 数字编码便于存储和运算求距离
  13. charactersNo={}
  14. characters=[]
  15. length=[]
  16. for i in range(26):
  17. charactersNo[chr(ord('a')+i)]=i
  18. characters.append(chr(ord('a')+i))
  19. for i in range(26):
  20. charactersNo[chr(ord('A')+i)]=i+26
  21. characters.append(chr(ord('A')+i))
  22. for i in range(10):
  23. charactersNo[chr(ord('0')+i)]=i+52
  24. characters.append(chr(ord('0')+i))
  25. punctuations=" .,?\'-:;!/\"<>&(+"
  26. for p in punctuations:
  27. charactersNo[p]=len(charactersNo)
  28. characters.append(p)
  29. def get_data():
  30. #读取了train_img和train_txt文件夹下的所有文件的读取路径
  31. #下面代码的作用是:
  32. #Imgs:列表结构 存储的是手写的英文图片
  33. #Y: 数组结构 存储的是图片对应的txt文件中句子,只不过存储的是字符转码后的数字
  34. #length: 数组结构 存储的是图片对应的txt文件中句子含有字符的数量
  35. imgFiles=glob.glob(os.path.join("train_img", "*"))
  36. imgFiles.sort()
  37. txtFiles=glob.glob(os.path.join("train_txt", "*"))
  38. txtFiles.sort()
  39. Imgs=[]
  40. Y=[]
  41. length=[]
  42. for i in range(len(imgFiles)):
  43. fin=open(txtFiles[i])
  44. line=fin.readlines()
  45. line=line[0]
  46. fin.close()
  47. y=np.asarray([0]*(len(line)))
  48. succ=True
  49. for j in range(len(line)):
  50. if line[j] not in charactersNo:
  51. succ=False
  52. break
  53. y[j]=charactersNo[line[j]]
  54. if not succ:
  55. continue
  56. Y.append(y)
  57. length.append(len(line))
  58. im = Image.open(imgFiles[i])
  59. width,height = im.size#1499,1386
  60. im = im.convert("L")
  61. Imgs.append(im)
  62. #np.asarray()函数 和 np.array()函数: 将list等结构转化成数组
  63. #区别是np.asarray()函数不是copy对象,而np.array()函数是copy对象
  64. print("train:",len(Imgs),len(Y))
  65. Y = np.asarray(Y)
  66. length = np.asarray(length)
  67. return Imgs, Y

ocr_forward.py

  1. import tensorflow as tf
  2. import os
  3. import glob
  4. import random
  5. import numpy as np
  6. from PIL import Image
  7. from PIL import ImageFilter
  8. import ocr_generated
  9. conv1_filter=32
  10. conv2_filter=64
  11. conv3_filter=128
  12. conv4_filter=256
  13. def get_weight(shape, regularizer):
  14. #参数w初始化,并且对w进行正则化处理,防止模型过拟合
  15. w = tf.Variable(tf.truncated_normal((shape), stddev=0.1, dtype=tf.float32))
  16. if regularizer != None: tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
  17. return w
  18. def get_bias(shape):
  19. #参数b初始化
  20. b = tf.Variable(tf.constant(0., shape=shape, dtype=tf.float32))
  21. return b
  22. def conv2d(x,w):
  23. #卷积层函数tf.nn.conv2d
  24. return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
  25. def max_pool_2x2(x, kernel_size):
  26. #池化层函数,在池化层采用最大池化,有效的提取特征
  27. return tf.nn.max_pool(x, ksize=kernel_size, strides=kernel_size, padding='VALID')
  28. def forward(x, train, regularizer):
  29. #前向传播中共使用了四层神经网络
  30. #第一层卷积层和池化层实现
  31. conv1_w = get_weight([3, 3, 1, conv1_filter], regularizer)
  32. conv1_b = get_bias([conv1_filter])
  33. conv1 = conv2d(x, conv1_w)
  34. relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_b))
  35. pool1 = max_pool_2x2(relu1, [1,2,2,1])
  36. #通过keep_prob参数控制drop_out函数对神经元的筛选
  37. if train:
  38. keep_prob = 0.6 #防止过拟合
  39. else:
  40. keep_prob = 1.0
  41. #第二层卷积层和池化层实现
  42. conv2_w = get_weight([5, 5, conv1_filter, conv2_filter], regularizer)
  43. conv2_b = get_bias([conv2_filter])
  44. conv2 = conv2d(tf.nn.dropout(pool1, keep_prob), conv2_w)
  45. relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_b))
  46. pool2 = max_pool_2x2(relu2, [1,2,1,1])
  47. #第三层卷积层和池化层
  48. conv3_w = get_weight([5, 5, conv2_filter, conv3_filter], regularizer)
  49. conv3_b = get_bias([conv3_filter])
  50. conv3 = conv2d(tf.nn.dropout(pool2, keep_prob), conv3_w)
  51. relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_b))
  52. pool3 = max_pool_2x2(relu3, [1,4,2,1])
  53. #第四层卷积层和池化层
  54. conv4_w = get_weight([5, 5, conv3_filter, conv4_filter], regularizer)
  55. conv4_b = get_bias([conv4_filter])
  56. conv4 = conv2d(tf.nn.dropout(pool3, keep_prob), conv4_w)
  57. relu4 = tf.nn.relu(tf.nn.bias_add(conv4, conv4_b))
  58. pool4 = max_pool_2x2(relu4, [1,7,1,1])
  59. rnn_inputs=tf.reshape(tf.nn.dropout(pool4,keep_prob),[-1,256,conv4_filter])
  60. num_hidden=512
  61. num_classes=len(ocr_generated.charactersNo)+1
  62. W = tf.Variable(tf.truncated_normal([num_hidden,num_classes],stddev=0.1), name="W")
  63. b = tf.Variable(tf.constant(0., shape=[num_classes]), name="b")
  64. #前向传播、反向传播,利用双向LSTM长时记忆循环网络
  65. #seq_len = tf.placeholder(tf.int32, shape=[None])
  66. #labels=tf.sparse_placeholder(tf.int32, shape=[None,2])
  67. cell_fw = tf.nn.rnn_cell.LSTMCell(num_hidden>>1, state_is_tuple=True)
  68. cell_bw = tf.nn.rnn_cell.LSTMCell(num_hidden>>1, state_is_tuple=True)
  69. #outputs_fw_bw: (output_fw, output_bw) 是(output_fw, output_bw)的元组
  70. outputs_fw_bw, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, rnn_inputs, dtype=tf.float32)
  71. #tf.contat 连接前向和反向得到的结果,在指定维度上进行连接
  72. outputs1 = tf.concat(outputs_fw_bw, 2)
  73. shape = tf.shape(x)
  74. batch_s, max_timesteps = shape[0], shape[1]
  75. outputs = tf.reshape(outputs1, [-1, num_hidden])
  76. #全连接层实现
  77. logits0 = tf.matmul(tf.nn.dropout(outputs,keep_prob), W) + b
  78. logits1 = tf.reshape(logits0, [batch_s, -1, num_classes])
  79. logits = tf.transpose(logits1, (1, 0, 2))
  80. y = tf.cast(logits, tf.float32)
  81. return y

ocr_backward.py

  1. import tensorflow as tf
  2. import ocr_forward
  3. import ocr_generated
  4. import os
  5. import glob
  6. import random
  7. import numpy as np
  8. from PIL import Image
  9. from PIL import ImageFilter
  10. REGULARIZER = 0.0001
  11. graphSize = (112,1024)
  12. MODEL_SAVE_PATH = "./model/"
  13. MODEL_NAME = "ocr_model"
  14. def transform(im, flag=True):
  15. '''
  16. 将传入的图片进行预处理:对图像进行图像缩放和数据增强
  17. Args:
  18. im : 传入的待处理的图片
  19. Return:
  20. graph : 返回经过预处理的图片
  21. #random.uniform(a, b)随机产生[a, b)之间的一个浮点数
  22. '''
  23. graph=np.zeros(graphSize[1]*graphSize[0]*1).reshape(graphSize[0],graphSize[1],1)
  24. deltaX=0
  25. deltaY=0
  26. ratio=1.464
  27. if flag:
  28. lowerRatio=max(1.269,im.size[1]*1.0/graphSize[0],im.size[0]*1.0/graphSize[1])
  29. upperRatio=max(lowerRatio,2.0)
  30. ratio=random.uniform(lowerRatio,upperRatio)
  31. deltaX=random.randint(0,int(graphSize[0]-im.size[1]/ratio))
  32. deltaY=random.randint(0,int(graphSize[1]-im.size[0]/ratio))
  33. else:
  34. ratio=max(1.464,im.size[1]*1.0/graphSize[0],im.size[0]*1.0/graphSize[1])
  35. deltaX=int(graphSize[0]-im.size[1]/ratio)>>1
  36. deltaY=int(graphSize[1]-im.size[0]/ratio)>>1
  37. height=int(im.size[1]/ratio)
  38. width=int(im.size[0]/ratio)
  39. data = im.resize((width,height),Image.ANTIALIAS).getdata()
  40. data = 1-np.asarray(data,dtype='float')/255.0
  41. data = data.reshape(height,width)
  42. graph[deltaX:deltaX+height,deltaY:deltaY+width,0]=data
  43. return graph
  44. def create_sparse(Y,dtype=np.int32):
  45. '''
  46. 对txt文本转化出来的数字序列Y作进一步的处理
  47. Args:
  48. Y
  49. Return:
  50. indices: 数组Y下标索引构成的新数组
  51. values: 下标索引对应的真实的数字码
  52. shape
  53. '''
  54. indices = []
  55. values = []
  56. for i in range(len(Y)):
  57. for j in range(len(Y[i])):
  58. indices.append((i,j))
  59. values.append(Y[i][j])
  60. indices = np.asarray(indices, dtype=np.int64)
  61. values = np.asarray(values, dtype=dtype)
  62. shape = np.asarray([len(Y), np.asarray(indices).max(0)[1] + 1], dtype=np.int64) #[64,180]
  63. return (indices, values, shape)
  64. def backward():
  65. x = tf.placeholder(tf.float32, shape=[None, graphSize[0], graphSize[1],1])
  66. y = ocr_forward.forward(x, True, REGULARIZER)
  67. #y_: 表示真实标签数据
  68. #Y : 从文本中读取到的标签数据,训练时传给y_
  69. #y : 神经网络预测的标签
  70. global_step = tf.Variable(0, trainable=False)#全局步骤计数
  71. seq_len = tf.placeholder(tf.int32, shape=[None])
  72. y_ = tf.sparse_placeholder(tf.int32, shape=[None,2])
  73. Imgs, Y = ocr_generated.get_data()
  74. #损失函数使用的ctc_loss函数
  75. loss = tf.nn.ctc_loss(y_, y, seq_len)
  76. cost = tf.reduce_mean(loss)
  77. #优化函数使用的是Adam算法
  78. optimizer1 = tf.train.AdamOptimizer(learning_rate=0.0003).minimize(cost, global_step=global_step)
  79. optimizer2 = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, global_step=global_step)
  80. width1_decoded, width1_log_prob=tf.nn.ctc_beam_search_decoder(y, seq_len, merge_repeated=False,beam_width=1)
  81. decoded, log_prob = tf.nn.ctc_beam_search_decoder(y, seq_len, merge_repeated=False)
  82. width1_acc = tf.reduce_mean(tf.edit_distance(tf.cast(width1_decoded[0], tf.int32), y_))
  83. acc = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), y_))
  84. nBatchArray=np.arange(Y.shape[0])
  85. epoch=100
  86. batchSize=32
  87. saver=tf.train.Saver(max_to_keep=1)
  88. config = tf.ConfigProto()
  89. config.gpu_options.allow_growth = True
  90. sess=tf.Session(config=config)
  91. bestDevErr=100.0
  92. with sess:
  93. sess.run(tf.global_variables_initializer())
  94. ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
  95. if ckpt and ckpt.model_checkpoint_path:
  96. saver.restore(sess, ckpt.model_checkpoint_path)
  97. #saver.restore(sess, "model/model.ckpt")
  98. #print(outputs.get_shape())
  99. for ep in range(epoch):
  100. np.random.shuffle(nBatchArray)
  101. for i in range(0, Y.shape[0], batchSize):
  102. batch_output = create_sparse(Y[nBatchArray[i:i+batchSize]])
  103. X=[None]*min(Y.shape[0]-i,batchSize)
  104. for j in range(len(X)):
  105. X[j]=transform(Imgs[nBatchArray[i+j]])
  106. feed_dict={x:X,seq_len :np.ones(min(Y.shape[0]-i,batchSize)) * 256, y_:batch_output}
  107. if ep<50:
  108. sess.run(optimizer1, feed_dict=feed_dict)
  109. else:
  110. sess.run(optimizer2, feed_dict=feed_dict)
  111. print(ep,i,"loss:",tf.reduce_mean(loss.eval(feed_dict=feed_dict)).eval(),"err:",tf.reduce_mean(width1_acc.eval(feed_dict=feed_dict)).eval())
  112. #saver.save(sess, "model/model.ckpt")
  113. saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME))
  114. def main():
  115. backward()
  116. if __name__ == '__main__':
  117. main()

ocr_test.py

  1. import os
  2. import glob
  3. import random
  4. import numpy as np
  5. from PIL import Image
  6. from PIL import ImageFilter
  7. import ocr_forward
  8. import tensorflow as tf
  9. REGULARIZER = 0.0001
  10. graphSize = (112,1024)
  11. def transform(im,flag=True):
  12. '''
  13. 对image做预处理,将其形状强制转化成(112, 1024, 1)的ndarray对象并返回
  14. Args:
  15. im = Image Object
  16. Return:
  17. graph = Ndarray Object
  18. '''
  19. graph=np.zeros(graphSize[1]*graphSize[0]*1).reshape(graphSize[0],graphSize[1],1)
  20. deltaX=0
  21. deltaY=0
  22. ratio=1.464
  23. if flag:
  24. lowerRatio=max(1.269,im.size[1]*1.0/graphSize[0],im.size[0]*1.0/graphSize[1])
  25. upperRatio=max(lowerRatio,1.659)
  26. ratio=random.uniform(lowerRatio,upperRatio)
  27. deltaX=random.randint(0,int(graphSize[0]-im.size[1]/ratio))
  28. deltaY=random.randint(0,int(graphSize[1]-im.size[0]/ratio))
  29. else:
  30. ratio=max(1.464,im.size[1]*1.0/graphSize[0],im.size[0]*1.0/graphSize[1])
  31. deltaX=int(graphSize[0]-im.size[1]/ratio)>>1
  32. deltaY=int(graphSize[1]-im.size[0]/ratio)>>1
  33. height=int(im.size[1]/ratio)
  34. width=int(im.size[0]/ratio)
  35. data = im.resize((width,height),Image.ANTIALIAS).getdata()
  36. data = 1-np.asarray(data,dtype='float')/255.0
  37. data = data.reshape(height,width)
  38. graph[deltaX:deltaX+height,deltaY:deltaY+width,0]=data
  39. return graph
  40. def countMargin(v,minSum,direction=True):
  41. '''
  42. Args:
  43. v = list
  44. minSum = Int
  45. Return:
  46. v中比minSum小的项数
  47. '''
  48. if direction:
  49. for i in range(len(v)):
  50. if v[i]>minSum:
  51. return i
  52. return len(v)
  53. for i in range(len(v)-1,-1,-1):
  54. if v[i]>minSum:
  55. return len(v)-i-1
  56. return len(v)
  57. def splitLine(seg,dataSum,h,maxHeight):
  58. i=0
  59. while i<len(seg)-1:
  60. if seg[i+1]-seg[i]<maxHeight:
  61. i+=1
  62. continue
  63. x=countMargin(dataSum[seg[i]:],3,True)
  64. y=countMargin(dataSum[:seg[i+1]],3,False)
  65. if seg[i+1]-seg[i]-x-y<maxHeight:
  66. i+=1
  67. continue
  68. idx=dataSum[seg[i]+x+h:seg[i+1]-h-y].argmin()+h
  69. if 0.33<=idx/(seg[i+1]-seg[i]-x-y)<=0.67:
  70. seg.insert(i+1,dataSum[seg[i]+x+h:seg[i+1]-y-h].argmin()+seg[i]+x+h)
  71. else:
  72. i+=1
  73. def getLine(im,data,upperbound=8,lowerbound=25,threshold=30,h=40,minHeight=35,maxHeight=120,beginX=20,endX=-20,beginY=125,endY=1100,merged=True):
  74. '''
  75. '''
  76. dataSum=data[:,beginX:endX].sum(1) #dataSum是一个一维向量
  77. lastPosition=beginY
  78. seg=[]
  79. flag=True
  80. cnt=0
  81. for i in range(beginY,endY):
  82. if dataSum[i]<=lowerbound:
  83. flag=True
  84. if dataSum[i]<=upperbound:
  85. cnt=0
  86. continue
  87. if flag:
  88. cnt+=1
  89. if cnt>=threshold:
  90. lineNo=np.argmin(dataSum[lastPosition:i])+lastPosition if threshold<=i-beginY else beginY
  91. if not merged or len(seg)==0 or lineNo-seg[-1]-countMargin(dataSum[seg[-1]:],5,True)-countMargin(dataSum[:lineNo],5,False)>minHeight:
  92. seg.append(lineNo)
  93. else:
  94. avg1=dataSum[max(0,seg[-1]-1):seg[-1]+2]
  95. avg1=avg1.sum()/avg1.shape[0]
  96. avg2=dataSum[max(0,lineNo-1):lineNo+2]
  97. avg2=avg2.sum()/avg2.shape[0]
  98. if avg1>avg2:
  99. seg[-1]=lineNo
  100. lastPosition=i
  101. flag=False
  102. lineNo=np.argmin(dataSum[lastPosition:]>10)+lastPosition if threshold<i else beginY
  103. if not merged or len(seg)==0 or lineNo-seg[-1]-countMargin(dataSum[seg[-1]:],10,True)-countMargin(dataSum[:lineNo],10,False)>minHeight:
  104. seg.append(lineNo)
  105. else:
  106. avg1=dataSum[max(0,seg[-1]-1):seg[-1]+2]
  107. avg1=avg1.sum()/avg1.shape[0]
  108. avg2=dataSum[max(0,lineNo-1):lineNo+2]
  109. avg2=avg2.sum()/avg2.shape[0]
  110. if avg1>avg2:
  111. seg[-1]=lineNo
  112. splitLine(seg,dataSum,h,maxHeight)
  113. results=[]
  114. for i in range(0,len(seg)-1):
  115. results.append(im.crop((0,seg[i]+countMargin(dataSum[seg[i]:],0),im.size[0],seg[i+1]-countMargin(dataSum[:seg[i+1]],0,False))))
  116. return results
  117. def calEditDistance(text1,text2):
  118. dp=np.asarray([0]*(len(text1)+1)*(len(text2)+1)).reshape(len(text1)+1,len(text2)+1)
  119. dp[0]=np.arange(len(text2)+1)
  120. dp[:,0]=np.arange(len(text1)+1)
  121. for i in range(1,len(text1)+1):
  122. for j in range(1,len(text2)+1):
  123. if text1[i-1]==text2[j-1]:
  124. dp[i,j]=dp[i-1,j-1]
  125. else:
  126. dp[i,j]=min(dp[i,j-1],dp[i-1,j],dp[i-1,j-1])+1
  127. return dp[-1,-1]
  128. def test():
  129. x = tf.placeholder(tf.float32, shape=[None, graphSize[0], graphSize[1], 1])
  130. y = ocr_forward.forward(x, False, REGULARIZER)
  131. seq_len = tf.placeholder(tf.int32, shape=[None])
  132. labels=tf.sparse_placeholder(tf.int32, shape=[None,2])
  133. loss = tf.nn.ctc_loss(labels, y, seq_len)
  134. cost = tf.reduce_mean(loss)
  135. width1_decoded, width1_log_prob=tf.nn.ctc_beam_search_decoder(y, seq_len, merge_repeated=False,beam_width=1)
  136. decoded, log_prob = tf.nn.ctc_beam_search_decoder(y, seq_len, merge_repeated=False)
  137. width1_acc = tf.reduce_mean(tf.edit_distance(tf.cast(width1_decoded[0], tf.int32), labels))
  138. acc = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels))
  139. saver=tf.train.Saver(max_to_keep=1)
  140. result=0
  141. imgFiles=glob.glob(os.path.join("test_img","*"))
  142. imgFiles.sort()
  143. txtFiles=glob.glob(os.path.join("test_txt","*"))
  144. txtFiles.sort()
  145. for i in range(len(imgFiles)):
  146. goldLines=[]
  147. fin=open(txtFiles[i])
  148. lines=fin.readlines()
  149. fin.close()
  150. for j in range(len(lines)):
  151. goldLines.append(lines[j])
  152. im = Image.open(imgFiles[i])
  153. width, height = im.size
  154. im = im.convert("L")
  155. data = im.getdata()
  156. data = 1-np.asarray(data,dtype='float')/255.0
  157. data = data.reshape(height,width)
  158. #getLine()将图片切割成一行一行的词条
  159. Imgs = getLine(im,data)
  160. config = tf.ConfigProto()
  161. config.gpu_options.allow_growth = True
  162. sess=tf.Session(config=config)
  163. with sess:
  164. saver.restore(sess,"model/model.ckpt")
  165. X=[None]*len(Imgs)
  166. for j in range(len(Imgs)):
  167. X[j]=transform(Imgs[j],False)
  168. feed_dict={inputs:X,seq_len :np.ones(len(X)) * 256}
  169. predict = decoded[0].eval(feed_dict=feed_dict)
  170. j=0
  171. predict_text=""
  172. gold_text="".join(goldLines)
  173. for k in range(predict.dense_shape[0]):
  174. while j<len(predict.indices) and predict.indices[j][0]==k:
  175. predict_text+=characters[predict.values[j]]
  176. j+=1
  177. predict_text+="\n"
  178. predict_text=predict_text.rstrip("\n")
  179. print("predict_text:")
  180. print(predict_text)
  181. fout=open("predict%s%s.txt"%(os.sep,txtFiles[i][txtFiles[i].find(os.sep)+1:txtFiles[i].rfind('.')]),'w')
  182. fout.write(predict_text)
  183. fout.close()
  184. print("gold_text:")
  185. print(gold_text)
  186. cer=calEditDistance(predict_text,gold_text)*1.0/len(gold_text)
  187. print("预测正确率: ", end='')
  188. print(cer)
  189. print()
  190. result+=cer
  191. print("test composition err:",result*1.0/len(imgFiles))
  192. def main():
  193. test()
  194. if __name__ == '__main__':
  195. main()