# direction.py (5.4 KB)

  1. import re
  2. from dataclasses import dataclass
  3. from enum import Enum
  4. from typing import Tuple, List
  5. import cv2
  6. import numpy as np
  7. from paddleocr import PaddleOCR
  8. from core.line_parser import LineParser
  9. class Direction(Enum):
  10. TOP = 0
  11. RIGHT = 1
  12. BOTTOM = 2
  13. LEFT = 3
  14. # 父类
  15. class OcrAnchor(object):
  16. # 输入识别anchor的名字, 如身份证号
  17. def __init__(self, name: str, d: List[Direction]):
  18. self.name = name
  19. # anchor位置
  20. self.direction = d
  21. def t_func(anchor, c, is_horizontal):
  22. if is_horizontal:
  23. return 0 if anchor[1] < c[1] else 2
  24. else:
  25. return 1 if anchor[0] > c[0] else 3
  26. def l_func(anchor, c, is_horizontal):
  27. if is_horizontal:
  28. return 0 if anchor[0] < c[0] else 2
  29. else:
  30. return 1 if anchor[1] < c[1] else 3
  31. def b_func(anchor, c, is_horizontal):
  32. if is_horizontal:
  33. return 0 if anchor[1] > c[1] else 2
  34. else:
  35. return 1 if anchor[0] < c[0] else 3
  36. def r_func(anchor, c, is_horizontal):
  37. if is_horizontal:
  38. return 0 if anchor[0] > c[0] else 2
  39. else:
  40. return 1 if anchor[1] > c[1] else 3
  41. self.direction_funcs = {
  42. Direction.TOP: t_func,
  43. Direction.BOTTOM: b_func,
  44. Direction.LEFT: l_func,
  45. Direction.RIGHT: r_func,
  46. }
  47. # 获取中心区域坐标 -> (x, y)
  48. def get_rec_area(self, res) -> Tuple[float, float]:
  49. """获得整张身份证的识别区域, 返回识别区域的中心点"""
  50. boxes = []
  51. for row in res:
  52. for r in row:
  53. boxes.extend(r.box)
  54. boxes = np.stack(boxes)
  55. l, t = np.min(boxes, 0)
  56. r, b = np.max(boxes, 0)
  57. return (l + r) / 2, (t + b) / 2
  58. def is_anchor(self, txt, box) -> bool:
  59. pass
  60. def find_anchor(self, res) -> Tuple[bool, float, float]:
  61. """
  62. 寻找锚点 中心点坐标
  63. """
  64. for row in res:
  65. for r in row:
  66. txt = r.txt.replace('-', '').replace(' ', '')
  67. box = r.box
  68. if self.is_anchor(txt, box):
  69. l, t = np.min(box, 0)
  70. r, b = np.max(box, 0)
  71. return True, (l + r) / 2, (t + b) / 2
  72. return False, 0., 0.
  73. # 定位 锚点 -> 角度
  74. def locate_anchor(self, res, is_horizontal) -> int:
  75. found, id_cx, id_cy = self.find_anchor(res)
  76. # 如果识别不到身份证号
  77. if not found: raise Exception(f'识别不到anchor{self.name}')
  78. cx, cy = self.get_rec_area(res)
  79. # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
  80. # print(f'cx: {cx}, cy: {cy}')
  81. pre = None
  82. for d in self.direction:
  83. f = self.direction_funcs.get(d, None)
  84. angle = f((id_cx, id_cy), (cx, cy), is_horizontal)
  85. if pre is None:
  86. pre = angle
  87. else:
  88. if angle != pre:
  89. raise Exception('angle is not compatiable')
  90. return pre
  91. # 子类1 人像面
  92. class FrontSideAnchor(OcrAnchor):
  93. def __init__(self, name: str, d: List[Direction]):
  94. super(FrontSideAnchor, self).__init__(name, d)
  95. def is_anchor(self, txt, box) -> bool:
  96. txts = re.findall('\d{10,18}', txt)
  97. if len(txts) > 0:
  98. return True
  99. return False
  100. def locate_anchor(self, res, is_horizontal) -> int:
  101. return super(FrontSideAnchor, self).locate_anchor(res, is_horizontal)
  102. # 子类2 国徽面
  103. class BackSideAnchor(OcrAnchor):
  104. def __init__(self, name: str, d: List[Direction]):
  105. super(BackSideAnchor, self).__init__(name, d)
  106. def is_anchor(self, txt, box) -> bool:
  107. txt = txt.replace('.', '')
  108. txts = re.findall('有效期', txt)
  109. if len(txts) > 0:
  110. return True
  111. return False
  112. def locate_anchor(self, res, is_horizontal) -> int:
  113. return super(BackSideAnchor, self).locate_anchor(res, is_horizontal)
  114. def detect_angle(result, ocr_anchor: OcrAnchor):
  115. filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
  116. lp = LineParser(result, filters)
  117. res = lp.parse()
  118. print('------ angle ocr -------')
  119. print(res)
  120. print('------ angle ocr -------')
  121. is_horizontal = lp.is_horizontal
  122. return ocr_anchor.locate_anchor(res, is_horizontal)
  123. @dataclass
  124. class AngleDetector(object):
  125. """
  126. 角度检测器
  127. """
  128. ocr: PaddleOCR
  129. def detect_angle(self, img, image_type):
  130. image_type = int(image_type)
  131. ocr_anchor = BackSideAnchor('有效期', [Direction.BOTTOM]) if image_type != 0 else FrontSideAnchor('身份证号', [
  132. Direction.BOTTOM])
  133. result = self.ocr.ocr(img, cls=True)
  134. if not result: raise Exception("对不起,未识别到有效区域,请检查后上传,谢谢")
  135. try:
  136. angle = detect_angle(result, ocr_anchor)
  137. return angle, result
  138. except Exception as e:
  139. print(e)
  140. # 如果第一次识别不到,旋转90度再识别
  141. img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  142. result = self.ocr.ocr(img, cls=True)
  143. angle = detect_angle(result, ocr_anchor)
  144. # 旋转90度之后要重新计算角度
  145. return (angle - 1 + 4) % 4, result