# direction.py -- detect the rotation (0/90/180/270 degrees) of an ID-card image from OCR anchors.
import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import Tuple

import cv2
import numpy as np
from paddleocr import PaddleOCR

from core.line_parser import LineParser
  9. class Directoin(Enum):
  10. TOP = 0
  11. RIGHT = 1
  12. BOTTOM = 2
  13. LEFT = 3
  14. # 父类
  15. class OcrAnchor(object):
  16. # 输入识别anchor的名字, 如身份证号
  17. def __init__(self, name: str, d: Directoin):
  18. self.name = name
  19. # anchor位置
  20. self.direction = d
  21. def t_func(anchor, c, is_horizontal):
  22. if is_horizontal:
  23. return 2 if anchor[1] > c[1] else 0
  24. else:
  25. return 3 if anchor[0] < c[0] else 1
  26. def b_func(anchor, c, is_horizontal):
  27. if is_horizontal:
  28. return 0 if anchor[1] > c[1] else 2
  29. else:
  30. return 1 if anchor[0] < c[0] else 3
  31. def l_func(anchor, c, is_horizontal):
  32. if is_horizontal:
  33. return 0 if anchor[0] < c[0] else 2
  34. else:
  35. return 1 if anchor[1] > c[1] else 3
  36. def r_func(anchor, c, is_horizontal):
  37. if is_horizontal:
  38. return 0 if anchor[0] > c[0] else 2
  39. else:
  40. return 1 if anchor[1] < c[1] else 3
  41. self.direction_funcs = {
  42. Directoin.TOP: t_func,
  43. Directoin.BOTTOM: b_func,
  44. Directoin.LEFT: l_func,
  45. Directoin.RIGHT: r_func,
  46. }
  47. # 获取中心区域坐标 -> (x, y)
  48. def get_rec_area(self, res) -> Tuple[float, float]:
  49. """获得整张身份证的识别区域, 返回识别区域的中心点"""
  50. boxes = []
  51. for row in res:
  52. for r in row:
  53. boxes.extend(r.box)
  54. boxes = np.stack(boxes)
  55. l, t = np.min(boxes, 0)
  56. r, b = np.max(boxes, 0)
  57. # 识别区域的box
  58. # big_box = [[l, t], [r, t], [r, b], [l, b]]
  59. # w, h = (r - l, b - t)
  60. return (l + r) / 2, (t + b) / 2
  61. # 判断是否是 锚点
  62. def is_anchor(self, txt, box) -> bool:
  63. pass
  64. # 找 锚点 -> 锚点坐标
  65. def find_anchor(self, res) -> Tuple[bool, float, float]:
  66. """
  67. 寻找身份证号的识别区域以及中心点
  68. """
  69. for row in res:
  70. for r in row:
  71. txt = r.txt.replace('-', '').replace(' ', '')
  72. box = r.box
  73. if self.is_anchor(txt, box):
  74. l, t = np.min(box, 0)
  75. r, b = np.max(box, 0)
  76. return True, (l + r) / 2, (t + b) / 2
  77. return False, 0., 0.
  78. # 定位 锚点 -> 角度
  79. # -> 锚点(x, y) pic(x, y) is_horizontal
  80. def locate_anchor(self, res, is_horizontal) -> int:
  81. found, id_cx, id_cy = self.find_anchor(res)
  82. # 如果识别不到身份证号
  83. if not found: raise Exception(f'识别不到anchor{self.name}')
  84. cx, cy = self.get_rec_area(res)
  85. # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
  86. # print(f'cx: {cx}, cy: {cy}')
  87. # 用k->get->func ==> f()
  88. f = self.direction_funcs.get(self.direction, None)
  89. return f((id_cx, id_cy), (cx, cy), is_horizontal)
  90. # if is_horizontal:
  91. # # 如果是水平的,身份证号的位置在相对识别区域的下方,方向则为0度,否则是180度
  92. # return 0 if id_cy > cy else 2
  93. # else:
  94. # # 如果是竖直的,身份证号的相对位置如果在左边,方向为90度,否则270度
  95. # return 1 if id_cx < cx else 3
  96. # 子类1 人像面
  97. class FrontSideAnchor(OcrAnchor):
  98. def __init__(self, name: str, d: Directoin):
  99. super(FrontSideAnchor, self).__init__(name, d)
  100. def is_anchor(self, txt, box) -> bool:
  101. txts = re.findall('\d{10,18}', txt)
  102. if len(txts) > 0:
  103. return True
  104. return False
  105. def locate_anchor(self, res, is_horizontal) -> int:
  106. return super(FrontSideAnchor, self).locate_anchor(res, is_horizontal)
  107. # 子类2 国徽面
  108. class BackSideAnchor(OcrAnchor):
  109. def __init__(self, name: str, d: Directoin):
  110. super(BackSideAnchor, self).__init__(name, d)
  111. def is_anchor(self, txt, box) -> bool:
  112. txt = txt.replace('.', '')
  113. txts = re.findall('有效期', txt)
  114. if len(txts) > 0:
  115. return True
  116. return False
  117. def locate_anchor(self, res, is_horizontal) -> int:
  118. return super(BackSideAnchor, self).locate_anchor(res, is_horizontal)
  119. # 调用以上 🔧工具
  120. # <- ocr_生数据
  121. # == ocr_熟数据(行处理后)
  122. # -> 角度0/1/2/3
  123. def detect_angle(result, ocr_anchor: OcrAnchor):
  124. lp = LineParser(result)
  125. res = lp.parse()
  126. print('------ angle ocr -------')
  127. print(res)
  128. print('------ angle ocr -------')
  129. is_horizontal = lp.is_horizontal
  130. return ocr_anchor.locate_anchor(res, is_horizontal)
  131. @dataclass
  132. class AngleDetector(object):
  133. """
  134. 角度检测器
  135. """
  136. ocr: PaddleOCR
  137. # 角度检测器
  138. # <- img(cv2格式) img_type
  139. # == result <- img(cv2)
  140. # -> angle result(ocr生)
  141. def detect_angle(self, img, image_type):
  142. image_type = int(image_type)
  143. # 初始化anchor对象
  144. ocr_anchor = BackSideAnchor('有效期', Directoin.BOTTOM) if image_type != 0 else FrontSideAnchor('身份证号',
  145. Directoin.BOTTOM)
  146. result = self.ocr.ocr(img, cls=True)
  147. try:
  148. angle = detect_angle(result, ocr_anchor)
  149. return angle, result
  150. except Exception as e:
  151. print(e)
  152. # 如果第一次识别不到,旋转90度再识别
  153. img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  154. result = self.ocr.ocr(img, cls=True)
  155. angle = detect_angle(result, ocr_anchor)
  156. # 旋转90度之后要重新计算角度
  157. return (angle - 1 + 4) % 4, result