# direction.py (5.3 KB)

import re
from dataclasses import dataclass
from typing import Tuple

import cv2
import numpy as np
from paddleocr import PaddleOCR

from core.line_parser import LineParser
  8. class OcrAnchor(object):
  9. # 输入识别anchor的名字, 如身份证号
  10. def __init__(self, name: str):
  11. self.name = name
  12. def get_rec_area(self, res) -> Tuple[float, float]:
  13. """获得整张身份证的识别区域, 返回识别区域的中心点"""
  14. boxes = []
  15. for row in res:
  16. for r in row:
  17. boxes.extend(r.box)
  18. boxes = np.stack(boxes)
  19. l, t = np.min(boxes, 0)
  20. r, b = np.max(boxes, 0)
  21. # 识别区域的box
  22. # big_box = [[l, t], [r, t], [r, b], [l, b]]
  23. # w, h = (r - l, b - t)
  24. return (l + r) / 2, (t + b) / 2
  25. def is_anchor(self, txt, box) -> bool:
  26. pass
  27. def find_anchor(self, res) -> Tuple[bool, float, float]:
  28. """寻找身份证号的识别区域以及中心点,根据身份证的w > h判断是否水平"""
  29. for row in res:
  30. for r in row:
  31. txt = r.txt.replace('-', '').replace(' ', '')
  32. box = r.box
  33. if self.is_anchor(txt, box):
  34. l, t = np.min(box, 0)
  35. r, b = np.max(box, 0)
  36. return True, (l + r) / 2, (t + b) / 2
  37. return False, 0., 0.
  38. def locate_anchor(self, res, is_horizontal) -> int:
  39. found, id_cx, id_cy = self.find_anchor(res)
  40. # 如果识别不到身份证号
  41. if not found: raise Exception(f'识别不到anchor{self.name}')
  42. cx, cy = self.get_rec_area(res)
  43. # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
  44. # print(f'cx: {cx}, cy: {cy}')
  45. if is_horizontal:
  46. # 如果是水平的,身份证号的位置在相对识别区域的下方,方向则为0度,否则是180度
  47. return 0 if id_cy > cy else 2
  48. else:
  49. # 如果是竖直的,身份证号的相对位置如果在左边,方向为90度,否则270度
  50. return 1 if id_cx < cx else 3
  51. class FrontSideAnchor(OcrAnchor):
  52. def __init__(self, name: str):
  53. super(FrontSideAnchor, self).__init__(name)
  54. def is_anchor(self, txt, box) -> bool:
  55. txts = re.findall('\d{10,18}', txt)
  56. if len(txts) > 0:
  57. return True
  58. return False
  59. def locate_anchor(self, res, is_horizontal) -> int:
  60. return super(FrontSideAnchor, self).locate_anchor(res, is_horizontal)
  61. class BackSideAnchor(OcrAnchor):
  62. def __init__(self, name: str):
  63. super(BackSideAnchor, self).__init__(name)
  64. def is_anchor(self, txt, box) -> bool:
  65. txt = txt.replace('.', '')
  66. txts = re.findall('有效期', txt)
  67. if len(txts) > 0:
  68. return True
  69. return False
  70. def locate_anchor(self, res, is_horizontal) -> int:
  71. return super(BackSideAnchor, self).locate_anchor(res, is_horizontal)
  72. def detect_angle(result, ocr_anchor: OcrAnchor):
  73. lp = LineParser(result)
  74. res = lp.parse()
  75. print('------ angle ocr -------')
  76. print(res)
  77. print('------ angle ocr -------')
  78. is_horizontal = lp.is_horizontal
  79. return ocr_anchor.locate_anchor(res, is_horizontal)
  80. @dataclass
  81. # 角度检测器
  82. class AngleDetector(object):
  83. ocr: PaddleOCR
  84. def detect_angle(self, img, image_type):
  85. image_type = int(image_type)
  86. ocr_anchor = BackSideAnchor('有效期') if image_type != 0 else FrontSideAnchor('身份证号')
  87. result = self.ocr.ocr(img, cls=True)
  88. try:
  89. angle = detect_angle(result, ocr_anchor)
  90. return angle, result
  91. except Exception as e:
  92. print(e)
  93. # 如果第一次识别不到,旋转90度再识别
  94. img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  95. result = self.ocr.ocr(img, cls=True)
  96. angle = detect_angle(result, ocr_anchor)
  97. # 旋转90度之后要重新计算角度
  98. return (angle - 1 + 4) % 4, result
  99. def _detect_back(self, image):
  100. mask = np.zeros(image.shape, dtype=np.uint8)
  101. gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  102. blur = cv2.GaussianBlur(gray, (3, 3), 0)
  103. adaptive = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 4)
  104. cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
  105. cnts = cnts[0] if len(cnts) == 2 else cnts[1]
  106. for c in cnts:
  107. area = cv2.contourArea(c)
  108. if area < 45000 and area > 20:
  109. cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
  110. mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
  111. h, w = mask.shape
  112. # Horizontal
  113. if w > h:
  114. left = mask[0:h, 0:0 + w // 2]
  115. right = mask[0:h, w // 2:]
  116. left_pixels = cv2.countNonZero(left)
  117. right_pixels = cv2.countNonZero(right)
  118. print(f'left: {left_pixels}, right: {right_pixels}')
  119. angle = 0 if left_pixels >= right_pixels else 2
  120. # Vertical
  121. else:
  122. top = mask[0:h // 2, 0:w]
  123. bottom = mask[h // 2:, 0:w]
  124. top_pixels = cv2.countNonZero(top)
  125. bottom_pixels = cv2.countNonZero(bottom)
  126. print(f'top: {top_pixels}, bottom: {bottom_pixels}')
  127. angle = 1 if bottom_pixels <= top_pixels else 3
  128. return angle, None