direction.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import re
  2. import cv2
  3. import numpy as np
  4. from dataclasses import dataclass
  5. from enum import Enum
  6. from typing import Tuple, List
  7. import cv2
  8. from paddleocr import PaddleOCR
  9. from core.line_parser import LineParser
  10. import matplotlib.pyplot as plt
  11. # 枚举
  12. class Direction(Enum):
  13. TOP = 0
  14. RIGHT = 1
  15. BOTTOM = 2
  16. LEFT = 3
  17. # 父类
  18. class OcrAnchor(object):
  19. # anchor的名字, 如身份证号、承办人等
  20. def __init__(self, name: str, d: List[Direction]):
  21. self.name = name
  22. self.direction = d
  23. # 定义枚举字典
  24. def t_func(anchor, c, is_horizontal):
  25. if is_horizontal:
  26. return 0 if anchor[1] < c[1] else 2
  27. else:
  28. return 1 if anchor[0] > c[0] else 3
  29. def l_func(anchor, c, is_horizontal):
  30. if is_horizontal:
  31. return 0 if anchor[0] < c[0] else 2
  32. else:
  33. return 1 if anchor[1] < c[1] else 3
  34. def b_func(anchor, c, is_horizontal):
  35. if is_horizontal:
  36. return 0 if anchor[1] > c[1] else 2
  37. else:
  38. return 1 if anchor[0] < c[0] else 3
  39. def r_func(anchor, c, is_horizontal):
  40. if is_horizontal:
  41. return 0 if anchor[0] > c[0] else 2
  42. else:
  43. return 1 if anchor[1] > c[1] else 3
  44. self.direction_funcs = {
  45. Direction.TOP: t_func,
  46. Direction.LEFT: l_func,
  47. Direction.BOTTOM: b_func,
  48. Direction.RIGHT: r_func
  49. }
  50. # pic中心点
  51. def get_pic_center(self, res) -> Tuple[float, float]:
  52. boxs = []
  53. for row in res:
  54. for r in row:
  55. boxs.extend(r.box)
  56. boxs = np.stack(boxs)
  57. l, t = np.min(boxs, 0)
  58. r, b = np.max(boxs, 0)
  59. return (l + r) / 2, (t + b) / 2
  60. # 是否有锚点
  61. def is_anchor(self, txt, box):
  62. pass
  63. # 找锚点
  64. def find_anchor(self, res):
  65. for row in res:
  66. for r in row:
  67. if self.is_anchor(r.txt, r.box):
  68. # l, t = np.min(r.box, 0)
  69. # r, b = np.max(r.box, 0)
  70. # return True, (l + r) / 2, (t + b) / 2
  71. return True, r.center[0], r.center[1]
  72. return False, 0., 0.
  73. # get angle
  74. def locate_anchor(self, res, is_horizontal):
  75. found, a_cx, a_cy = self.find_anchor(res)
  76. cx, cy = self.get_pic_center(res)
  77. if found is False: raise Exception(f'识别不到anchor{self.name}')
  78. pre = None
  79. for d in self.direction:
  80. angle_func = self.direction_funcs.get(d, None)
  81. angle = angle_func((a_cx, a_cy), (cx, cy), is_horizontal)
  82. if pre is None:
  83. pre = angle
  84. else:
  85. if pre != angle:
  86. raise Exception('angle is not compatible')
  87. return pre
  88. # 子类1 户口本首页1
  89. class FrontAnchor(OcrAnchor):
  90. def __init__(self, name: str, d: List[Direction]):
  91. super(FrontAnchor, self).__init__(name, d)
  92. def is_anchor(self, txt, box):
  93. txts = re.findall('承办人', txt)
  94. if len(txts) > 0:
  95. return True
  96. return False
  97. def locate_anchor(self, res, is_horizontal):
  98. return super(FrontAnchor, self).locate_anchor(res, is_horizontal)
  99. # 子类2 常驻人口页0
  100. class PeopleAnchor(OcrAnchor):
  101. def __init__(self, name: str, d: List[Direction]):
  102. super(PeopleAnchor, self).__init__(name, d)
  103. def is_anchor(self, txt, box):
  104. txts = re.findall('常住', txt) or re.findall('登记卡', txt)
  105. if len(txts) > 0:
  106. return True
  107. return False
  108. def locate_anchor(self, res, is_horizontal):
  109. return super(PeopleAnchor, self).locate_anchor(res, is_horizontal)
  110. # 调用以上 🔧工具
  111. # <- ocr_生数据
  112. # == ocr_熟数据(行处理后)
  113. # -> 角度0/1/2/3
  114. def detect_angle(result, ocr_anchor: OcrAnchor):
  115. lp = LineParser(result)
  116. res = lp.parse()
  117. print('------ angle ocr -------')
  118. print(res)
  119. print('------ angle ocr -------')
  120. is_horizontal = lp.is_horizontal
  121. return ocr_anchor.locate_anchor(res, is_horizontal)
  122. @dataclass
  123. class AngleDetector(object):
  124. """
  125. 角度检测器
  126. """
  127. ocr: PaddleOCR
  128. # 角度检测器
  129. # <- img(cv2格式) img_type
  130. # == result <- img(cv2)
  131. # -> angle result(ocr生)
  132. def detect_angle(self, img, image_type):
  133. image_type = int(image_type)
  134. ocr_anchor = PeopleAnchor('常住', [Direction.TOP]) if image_type == 0 else FrontAnchor('承办人', [Direction.BOTTOM,
  135. Direction.LEFT])
  136. result = self.ocr.ocr(img, cls=True)
  137. # image = img.copy()
  138. # for box in result:
  139. # box = np.reshape(np.array(box[0]), [-1, 1, 2]).astype(np.int64)
  140. # image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  141. # cv2.imwrite("./test.jpg", image)
  142. try:
  143. angle = detect_angle(result, ocr_anchor)
  144. return angle, result
  145. except Exception as e:
  146. print(e)
  147. # 如果第一次识别不到,旋转90度再识别
  148. img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  149. result = self.ocr.ocr(img, cls=True)
  150. angle = detect_angle(result, ocr_anchor)
  151. # 旋转90度之后要重新计算角度
  152. return (angle - 1 + 4) % 4, result