ocr.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from dataclasses import dataclass
  2. import cv2
  3. import numpy as np
  4. import math
  5. from paddleocr import PaddleOCR, draw_ocr
  6. from core.direction import *
  7. from core.line_parser import LineParser
  8. from core.parser import *
  9. from PIL import Image
  10. @dataclass
  11. class BankOcr:
  12. ocr: PaddleOCR
  13. angle_detector: AngleDetector
  14. def predict(self, image: np.ndarray):
  15. image, angle, ori_result = self._pre_process(image)
  16. print(f'---------- detect angle: {angle} 角度 --------')
  17. # 这里使用自己训练的检测识别模型,在此之前,理想情况下,所有的银行卡的角度都已经是0,(正向)
  18. _, _, result = self._ocr(image)
  19. # self.imshow(image, result) # 将检测图片保存
  20. return self._post_process(result, angle)
  21. def imshow(self, image, result):
  22. img = Image.fromarray(image).convert("RGB")
  23. boxes = [line[0] for line in result]
  24. txts = [line[1][0] for line in result]
  25. scores = [line[1][1] for line in result]
  26. im_show = draw_ocr(img, boxes, txts, scores, font_path="./simfang.ttf")
  27. im_show = Image.fromarray(im_show)
  28. im_show.save("./img.jpg")
  29. def _pre_process(self, image: np.ndarray):
  30. angle, result = self.angle_detector.detect_angle(image)
  31. if angle == 1:
  32. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  33. if angle == 2:
  34. image = cv2.rotate(image, cv2.ROTATE_180)
  35. if angle == 3:
  36. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  37. # if -60 <= rotate_angle <= -20 or 20 <= rotate_angle <= 60:
  38. # print("需要旋转角度")
  39. # image = imutils.rotate(image, rotate_angle)
  40. # 因为有些img像素过大,导致检测框效果不好,识别就会出问题
  41. h, w, _ = image.shape
  42. h_ratio = 1 if h <= 1000 else h / 1000
  43. w_ratio = 1 if w <= 1000 else w / 1000
  44. if h_ratio == 1 and w_ratio == 1:
  45. return image, angle, result
  46. elif h_ratio != 1 or w_ratio != 1:
  47. ratio = h_ratio if h_ratio > w_ratio else w_ratio
  48. image = cv2.resize(image, (w // math.ceil(ratio), h // math.ceil(ratio)))
  49. print(image.shape)
  50. return image, angle, result
  51. def _ocr(self, image):
  52. # 获取模型检测结果,因为是正的照片了,所以不需要方向分类器
  53. result = self.ocr.ocr(image, cls=False)
  54. print("------------------")
  55. print("result:", result)
  56. print("------------------")
  57. # result=[] 就用官方再检测
  58. if not result:
  59. print("需要再次进行官方的检测代码。。。。。。。。。。。。")
  60. result = self.angle_detector.origin_detect(image)
  61. # 如果还是空,那就检测不出来
  62. if not result:
  63. raise Exception('经过两次检测都无法识别!!!')
  64. confs = [line[1][1] for line in result]
  65. txts = [line[1][0] for line in result]
  66. return txts, confs, result
  67. # result!=[] 就判断一些规则
  68. if result:
  69. confs = [line[1][1] for line in result]
  70. print("自己的检测模型得到的conf:", confs)
  71. # 根绝len(result)分规则判断
  72. if len(result) == 1:
  73. if confs[0] > 0.987:
  74. txts = [line[1][0] for line in result]
  75. return txts, confs, result
  76. else:
  77. print("len(result)=1时,再次用官方代码检测。。。。。。")
  78. result = self.angle_detector.origin_detect(image)
  79. elif len(result) == 2:
  80. # 1.判断两个检测框在不在一行
  81. is_oneline = self.angle_detector.det_oneline(result)
  82. # 2.如果不在一行
  83. if not is_oneline:
  84. txts = [line[1][0] for line in result]
  85. if not (any(map(lambda x: x > 0.987, confs)) and len(re.findall('\d{16,20}', txts)) > 0):
  86. print("len(result)=2,但是不在一行。。。。。。")
  87. result = self.angle_detector.origin_detect(image)
  88. # 3. 如果在一行
  89. elif is_oneline:
  90. if all(map(lambda x: x > 0.987, confs)):
  91. l_box, r_box = [], []
  92. l_box.extend(result[0][0])
  93. r_box.extend(result[1][0])
  94. l_max, _ = np.max(l_box, 0)
  95. r_min, _ = np.min(r_box, 0)
  96. if l_max > r_min:
  97. print("len(result)=2,在一行,但有重叠。。。。。。")
  98. result = self.angle_detector.origin_detect(image)
  99. else:
  100. print("len(result)=2,在一行,但有一个检测不行。。。。。。")
  101. result = self.angle_detector.origin_detect(image)
  102. elif len(result) > 2:
  103. print("len(result)=3,直接换官方检测。。。。。。")
  104. result = self.angle_detector.origin_detect(image)
  105. # elif len(result) == 2 and all(map(lambda x: x > 0.975, confs)):
  106. # l_box, r_box = [], []
  107. # l_box.extend(result[0][0])
  108. # r_box.extend(result[1][0])
  109. #
  110. # l_max, _ = np.max(l_box, 0)
  111. # r_min, _ = np.min(r_box, 0)
  112. #
  113. # if l_max > r_min:
  114. # print("说明自己的检测模型不好")
  115. # result = self.angle_detector.origin_detect(image)
  116. # else:
  117. # # 一般情况下,len=1
  118. # flag = 0
  119. # if all(map(lambda x: x >= 0.975, confs)):
  120. # flag = 1
  121. #
  122. # if flag == 0:
  123. # print("需要再次进行官方的检测代码。。。。。。。。。。。。")
  124. # result = self.angle_detector.origin_detect(image)
  125. # 如果还是空,那就检测不出来
  126. if not result:
  127. raise Exception('经过两次检测都无法识别!!!')
  128. confs = [line[1][1] for line in result]
  129. # 将检测到的文字放到一个列表中
  130. txts = [line[1][0] for line in result]
  131. return txts, confs, result
  132. def _post_process(self, raw_result, angle: int):
  133. # 把测试图片 喂给 OCR 返回给 self.raw_results
  134. line_parser = LineParser(raw_result)
  135. line_results = line_parser.parse()
  136. conf = line_parser.confidence
  137. parser = Parser(line_results)
  138. content = parser.parse()
  139. return {
  140. "confidence": conf,
  141. "orientation": angle,
  142. "number": content["number"].to_dict(),
  143. }