# ocr.py 4.2 KB

from dataclasses import dataclass
import math
import re

import cv2
import numpy as np
from paddleocr import PaddleOCR, draw_ocr
from PIL import Image

from core.direction import *
from core.line_parser import LineParser
from core.parser import *
  10. @dataclass
  11. class BankOcr:
  12. ocr: PaddleOCR
  13. angle_detector: AngleDetector
  14. def predict(self, image: np.ndarray):
  15. image, angle, ori_result = self._pre_process(image)
  16. print(f'---------- detect angle: {angle} 角度 --------')
  17. _, _, result = self._ocr(image)
  18. return self._post_process(result, angle)
  19. def imshow(self, image, result):
  20. img = Image.fromarray(image).convert("RGB")
  21. boxes = [line[0] for line in result]
  22. txts = [line[1][0] for line in result]
  23. scores = [line[1][1] for line in result]
  24. im_show = draw_ocr(img, boxes, txts, scores, font_path="./simfang.ttf")
  25. im_show = Image.fromarray(im_show)
  26. im_show.save("./img.jpg")
  27. def _pre_process(self, image: np.ndarray):
  28. angle, result = self.angle_detector.detect_angle(image)
  29. if angle == 1:
  30. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  31. if angle == 2:
  32. image = cv2.rotate(image, cv2.ROTATE_180)
  33. if angle == 3:
  34. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  35. h, w, _ = image.shape
  36. h_ratio = 1 if h <= 1000 else h / 1000
  37. w_ratio = 1 if w <= 1000 else w / 1000
  38. if h_ratio != 1 or w_ratio != 1:
  39. ratio = h_ratio if h_ratio > w_ratio else w_ratio
  40. image = cv2.resize(image, (w // math.ceil(ratio), h // math.ceil(ratio)))
  41. print(image.shape)
  42. return image, angle, result
  43. def _ocr(self, image):
  44. # 获取模型检测结果,因为是正的照片了,所以不需要方向分类器
  45. result = self.ocr.ocr(image, cls=False)
  46. print("------------------")
  47. print("result:", result)
  48. print("------------------")
  49. # result=[] 就用官方再检测
  50. if not result:
  51. result = self.angle_detector.origin_detect(image)
  52. # 如果还是空,那就检测不出来
  53. if not result:
  54. raise Exception('识别出错')
  55. confs = [line[1][1] for line in result]
  56. txts = [line[1][0] for line in result]
  57. return txts, confs, result
  58. if result:
  59. confs = [line[1][1] for line in result]
  60. if len(result) == 1:
  61. if confs[0] > 0.987:
  62. txts = [line[1][0] for line in result]
  63. return txts, confs, result
  64. else:
  65. result = self.angle_detector.origin_detect(image)
  66. elif len(result) == 2:
  67. is_oneline = self.angle_detector.det_oneline(result)
  68. if not is_oneline:
  69. txts = [line[1][0] for line in result]
  70. if not (any(map(lambda x: x > 0.987, confs)) and len(re.findall('\d{16,20}', txts)) > 0):
  71. result = self.angle_detector.origin_detect(image)
  72. elif is_oneline:
  73. if all(map(lambda x: x > 0.987, confs)):
  74. l_box, r_box = [], []
  75. l_box.extend(result[0][0])
  76. r_box.extend(result[1][0])
  77. l_max, _ = np.max(l_box, 0)
  78. r_min, _ = np.min(r_box, 0)
  79. if l_max > r_min:
  80. result = self.angle_detector.origin_detect(image)
  81. else:
  82. result = self.angle_detector.origin_detect(image)
  83. elif len(result) > 2:
  84. result = self.angle_detector.origin_detect(image)
  85. confs = [line[1][1] for line in result]
  86. txts = [line[1][0] for line in result]
  87. return txts, confs, result
  88. def _post_process(self, raw_result, angle: int):
  89. line_parser = LineParser(raw_result)
  90. line_results = line_parser.parse()
  91. conf = line_parser.confidence
  92. parser = Parser(line_results)
  93. content = parser.parse()
  94. return {
  95. "confidence": conf,
  96. "orientation": angle,
  97. "number": content["number"].to_dict(),
  98. }