direction.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import re
  2. from dataclasses import dataclass
  3. from enum import Enum
  4. from typing import Tuple, List
  5. import cv2
  6. import numpy as np
  7. from paddleocr import PaddleOCR
  8. from core.line_parser import LineParser
  9. class Direction(Enum):
  10. TOP = 0
  11. RIGHT = 1
  12. BOTTOM = 2
  13. LEFT = 3
  14. # 父类
  15. class OcrAnchor(object):
  16. # 输入识别anchor的名字, 如身份证号
  17. def __init__(self, name: str, d: List[Direction]):
  18. self.name = name
  19. # anchor位置
  20. self.direction = d
  21. def t_func(anchor, c, is_horizontal):
  22. if is_horizontal:
  23. return 0 if anchor[1] < c[1] else 2
  24. else:
  25. return 1 if anchor[0] > c[0] else 3
  26. def l_func(anchor, c, is_horizontal):
  27. if is_horizontal:
  28. return 0 if anchor[0] < c[0] else 2
  29. else:
  30. return 1 if anchor[1] < c[1] else 3
  31. def b_func(anchor, c, is_horizontal):
  32. if is_horizontal:
  33. return 0 if anchor[1] > c[1] else 2
  34. else:
  35. return 1 if anchor[0] < c[0] else 3
  36. def r_func(anchor, c, is_horizontal):
  37. if is_horizontal:
  38. return 0 if anchor[0] > c[0] else 2
  39. else:
  40. return 1 if anchor[1] > c[1] else 3
  41. self.direction_funcs = {
  42. Direction.TOP: t_func,
  43. Direction.BOTTOM: b_func,
  44. Direction.LEFT: l_func,
  45. Direction.RIGHT: r_func,
  46. }
  47. # 获取中心区域坐标 -> (x, y)
  48. def get_rec_area(self, res) -> Tuple[float, float]:
  49. """获得整张身份证的识别区域, 返回识别区域的中心点"""
  50. boxes = []
  51. for row in res:
  52. for r in row:
  53. boxes.extend(r.box)
  54. boxes = np.stack(boxes)
  55. l, t = np.min(boxes, 0)
  56. r, b = np.max(boxes, 0)
  57. # 识别区域的box
  58. # big_box = [[l, t], [r, t], [r, b], [l, b]]
  59. # w, h = (r - l, b - t)
  60. return (l + r) / 2, (t + b) / 2
  61. # 判断是否是 锚点
  62. def is_anchor(self, txt, box) -> bool:
  63. pass
  64. # 找 锚点 -> 锚点坐标
  65. def find_anchor(self, res) -> Tuple[bool, float, float]:
  66. """
  67. 寻找锚点 中心点坐标
  68. """
  69. for row in res:
  70. for r in row:
  71. txt = r.txt.replace('-', '').replace(' ', '')
  72. box = r.box
  73. if self.is_anchor(txt, box):
  74. l, t = np.min(box, 0)
  75. r, b = np.max(box, 0)
  76. return True, (l + r) / 2, (t + b) / 2
  77. return False, 0., 0.
  78. # 定位 锚点 -> 角度
  79. # -> 锚点(x, y) pic(x, y) is_horizontal
  80. def locate_anchor(self, res, is_horizontal) -> int:
  81. found, id_cx, id_cy = self.find_anchor(res)
  82. # 如果识别不到身份证号
  83. if not found: raise Exception(f'识别不到anchor{self.name}')
  84. cx, cy = self.get_rec_area(res)
  85. # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
  86. # print(f'cx: {cx}, cy: {cy}')
  87. pre = None
  88. for d in self.direction:
  89. f = self.direction_funcs.get(d, None)
  90. angle = f((id_cx, id_cy), (cx, cy), is_horizontal)
  91. if pre is None:
  92. pre = angle
  93. else:
  94. if angle != pre:
  95. raise Exception('angle is not compatiable')
  96. return pre
  97. # if is_horizontal:
  98. # # 如果是水平的,身份证号的位置在相对识别区域的下方,方向则为0度,否则是180度
  99. # return 0 if id_cy > cy else 2
  100. # else:
  101. # # 如果是竖直的,身份证号的相对位置如果在左边,方向为90度,否则270度
  102. # return 1 if id_cx < cx else 3
  103. # 子类1 人像面
  104. class CETAnchor(OcrAnchor):
  105. def __init__(self, name: str, d: List[Direction]):
  106. super(CETAnchor, self).__init__(name, d)
  107. def is_anchor(self, txt, box) -> bool:
  108. txts = re.findall('全国大学英语', txt)
  109. if len(txts) > 0:
  110. return True
  111. return False
  112. def locate_anchor(self, res, is_horizontal) -> int:
  113. return super(CETAnchor, self).locate_anchor(res, is_horizontal)
  114. # 子类2 国徽面
  115. class TEMAnchor(OcrAnchor):
  116. def __init__(self, name: str, d: List[Direction]):
  117. super(TEMAnchor, self).__init__(name, d)
  118. def is_anchor(self, txt, box) -> bool:
  119. txts = re.findall('证书编号', txt)
  120. if len(txts) > 0:
  121. return True
  122. return False
  123. def locate_anchor(self, res, is_horizontal) -> int:
  124. return super(TEMAnchor, self).locate_anchor(res, is_horizontal)
  125. # 调用以上 🔧工具
  126. # <- ocr_生数据
  127. # == ocr_熟数据(行处理后)
  128. # -> 角度0/1/2/3
  129. def detect_angle(result, ocr_anchor: OcrAnchor):
  130. filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
  131. lp = LineParser(result, filters)
  132. res = lp.parse()
  133. print('------ angle ocr -------')
  134. print(res)
  135. print('------ angle ocr -------')
  136. is_horizontal = lp.is_horizontal
  137. return ocr_anchor.locate_anchor(res, is_horizontal)
  138. @dataclass
  139. class AngleDetector(object):
  140. """
  141. 角度检测器
  142. """
  143. ocr: PaddleOCR
  144. # 角度检测器
  145. # <- img(cv2格式) img_type
  146. # == result <- img(cv2)
  147. # -> angle result(ocr生)
  148. def detect_angle(self, img):
  149. # image_type = int(image_type)
  150. # result = self.ocr.ocr(img, cls=True)
  151. image_type, result = self.detect_img(img)
  152. ocr_anchor = CETAnchor('CET', [Direction.TOP]) if image_type == 0 else TEMAnchor('TEM', [
  153. Direction.BOTTOM])
  154. try:
  155. angle = detect_angle(result, ocr_anchor)
  156. return angle, result, image_type
  157. except Exception as e:
  158. print(e)
  159. # 如果第一次识别不到,旋转90度再识别
  160. img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  161. result = self.ocr.ocr(img, cls=True)
  162. angle = detect_angle(result, ocr_anchor)
  163. # 旋转90度之后要重新计算角度
  164. return (angle - 1 + 4) % 4, result, image_type
  165. def detect_img(self, img):
  166. result = self.ocr.ocr(img, cls=True)
  167. for res in result:
  168. if "报告单" in res[1][0]:
  169. return 0, result
  170. raise Exception("不支持专四专八")
  171. # return 1, result