direction.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import re
  2. from dataclasses import dataclass
  3. from enum import Enum
  4. from typing import Tuple, List
  5. import cv2
  6. import numpy as np
  7. from paddleocr import PaddleOCR
  8. from core.line_parser import LineParser
  9. from utils.time import timeit
  10. class Direction(Enum):
  11. TOP = 0
  12. RIGHT = 1
  13. BOTTOM = 2
  14. LEFT = 3
  15. class OcrAnchor(object):
  16. def __init__(self, name: str, d: List[Direction]):
  17. self.name = name
  18. # anchor位置
  19. self.direction = d
  20. def t_func(anchor, c, is_horizontal):
  21. if is_horizontal:
  22. return 0 if anchor[1] < c[1] else 2
  23. else:
  24. return 1 if anchor[0] > c[0] else 3
  25. def l_func(anchor, c, is_horizontal):
  26. if is_horizontal:
  27. return 0 if anchor[0] < c[0] else 2
  28. else:
  29. return 1 if anchor[1] < c[1] else 3
  30. def b_func(anchor, c, is_horizontal):
  31. if is_horizontal:
  32. return 0 if anchor[1] > c[1] else 2
  33. else:
  34. return 1 if anchor[0] < c[0] else 3
  35. def r_func(anchor, c, is_horizontal):
  36. if is_horizontal:
  37. return 0 if anchor[0] > c[0] else 2
  38. else:
  39. return 1 if anchor[1] > c[1] else 3
  40. self.direction_funcs = {
  41. Direction.TOP: t_func,
  42. Direction.BOTTOM: b_func,
  43. Direction.LEFT: l_func,
  44. Direction.RIGHT: r_func,
  45. }
  46. # 获取中心区域坐标 -> (x, y)
  47. def get_pic_center(self, res) -> Tuple[float, float]:
  48. """
  49. 获得整张图片的识别区域,
  50. 返回识别区域的中心点坐标
  51. """
  52. boxes = []
  53. for row in res:
  54. for r in row:
  55. boxes.extend(r.box)
  56. boxes = np.stack(boxes)
  57. l, t = np.min(boxes, 0)
  58. r, b = np.max(boxes, 0)
  59. return (l + r) / 2, (t + b) / 2
  60. def is_anchor(self, txt, box) -> bool:
  61. pass
  62. def find_anchor(self, res) -> Tuple[bool, float, float]:
  63. """
  64. 寻找锚点 中心点坐标
  65. """
  66. for row in res:
  67. for r in row:
  68. if self.is_anchor(r.txt, r.box):
  69. return True, r.center[0], r.center[1]
  70. return False, 0., 0.
  71. def locate_anchor(self, res, is_horizontal):
  72. found, a_cx, a_cy = self.find_anchor(res)
  73. cx, cy = self.get_pic_center(res)
  74. if found is False: raise Exception(f'识别不到anchor{self.name}')
  75. pre = None
  76. for d in self.direction:
  77. angle_func = self.direction_funcs.get(d, None)
  78. angle = angle_func((a_cx, a_cy), (cx, cy), is_horizontal)
  79. if pre is None:
  80. pre = angle
  81. else:
  82. if pre != angle:
  83. raise Exception('angle is not compatible')
  84. return pre
  85. # 子类0: 教育部学籍在线验证报告
  86. class ReportAnchor(OcrAnchor):
  87. def __init__(self, name: str, d: List[Direction]):
  88. super(ReportAnchor, self).__init__(name, d)
  89. def is_anchor(self, txt, box):
  90. txts = re.findall('查看该', txt) or re.findall('更新日期', txt)
  91. if len(txts) > 0:
  92. return True
  93. return False
  94. def locate_anchor(self, res, is_horizontal):
  95. return super(ReportAnchor, self).locate_anchor(res, is_horizontal)
  96. # 子类1: 教育部学历证书电子注册备案表
  97. class RecordAnchor(OcrAnchor):
  98. def __init__(self, name: str, d: List[Direction]):
  99. super(RecordAnchor, self).__init__(name, d)
  100. def is_anchor(self, txt, box):
  101. txts = re.findall('注册备案表', txt)
  102. if len(txts) > 0:
  103. return True
  104. return False
  105. def locate_anchor(self, res, is_horizontal):
  106. return super(RecordAnchor, self).locate_anchor(res, is_horizontal)
  107. # 子类2: 中国高等教育证书查询结果(零散查询)
  108. class ScattedAnchor(OcrAnchor):
  109. def __init__(self, name: str, d: List[Direction]):
  110. super(ScattedAnchor, self).__init__(name, d)
  111. def is_anchor(self, txt, box):
  112. txts = re.findall('教育学历', txt)
  113. if len(txts) > 0:
  114. return True
  115. return False
  116. def locate_anchor(self, res, is_horizontal):
  117. return super(ScattedAnchor, self).locate_anchor(res, is_horizontal)
  118. @timeit
  119. def detect_angle(result, ocr_anchor: OcrAnchor):
  120. lp = LineParser(result)
  121. res = lp.parse()
  122. print('------ angle ocr -------')
  123. print(res)
  124. print('------ angle ocr -------')
  125. is_horizontal = lp.is_horizontal
  126. return ocr_anchor.locate_anchor(res, is_horizontal)
  127. @dataclass
  128. class AngleDetector(object):
  129. """
  130. 角度检测器
  131. """
  132. ocr: PaddleOCR
  133. def detect_angle(self, img, image_type):
  134. image_type = int(image_type)
  135. if image_type == 0:
  136. ocr_anchor = ReportAnchor('0:教育部学历证书电子注册备案表', [Direction.TOP])
  137. elif image_type == 1:
  138. ocr_anchor = RecordAnchor('1:教育部学籍在线验证报告', [Direction.TOP])
  139. elif image_type == 2:
  140. ocr_anchor = ScattedAnchor('2:中国高等教育证书查询结果(零散查询)', [Direction.TOP])
  141. else:
  142. raise Exception('未传入 image_type')
  143. result = self.ocr.ocr(img, cls=True)
  144. try:
  145. angle = detect_angle(result, ocr_anchor)
  146. return angle, result
  147. except Exception as e:
  148. print(e)
  149. img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  150. result = self.ocr.ocr(img, cls=True)
  151. angle = detect_angle(result, ocr_anchor)
  152. return (angle - 1 + 4) % 4, result