# line_parser.py
  1. import math
  2. import numpy as np
  3. from dataclasses import dataclass
  4. # result 对象
  5. @dataclass
  6. class OcrResult(object):
  7. box: np.ndarray
  8. txt: str
  9. conf: float
  10. def __hash__(self):
  11. return hash(repr(self))
  12. def __repr__(self):
  13. return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
  14. @property
  15. def lt(self):
  16. l, t = np.min(self.box, 0)
  17. return [l, t]
  18. @property
  19. def rb(self):
  20. r, b = np.max(self.box, 0)
  21. return [r, b]
  22. @property
  23. def wh(self):
  24. l, t = self.lt
  25. r, b = self.rb
  26. return [r - l, b - t]
  27. @property
  28. def area(self):
  29. w, h = self.wh
  30. return w * h
  31. @property
  32. def is_slope(self):
  33. """
  34. function: 10~60,-60~-10度之间,需要旋转图片,因为目前的检测模型对于倾斜角度的不能检测
  35. return: 需要旋转的角度 ---> tan
  36. """
  37. p0 = self.box[0]
  38. p1 = self.box[1]
  39. if p0[0] == p1[0]: # 如果是正常的那就不用转
  40. return 0
  41. slope = 1. * (p1[1] - p0[1]) / (p1[0] - p0[0])
  42. return slope
  43. @property
  44. def center(self):
  45. l, t = self.lt
  46. r, b = self.rb
  47. return [(r + l) / 2, (b + t) / 2]
  48. def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
  49. y_idx = 0 + is_horizontal
  50. x_idx = 1 - y_idx
  51. if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
  52. if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
  53. eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
  54. dist = abs(self.center[y_idx] - b.center[y_idx])
  55. return dist < eps
  56. # 行处理器
  57. class LineParser(object):
  58. def __init__(self, ocr_raw_result, filters=None):
  59. # self.rotate_angle = 0
  60. if filters is None:
  61. filters = [lambda x: x.is_slope]
  62. self.ocr_res = []
  63. for re in ocr_raw_result:
  64. o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  65. # if any([f(o) for f in filters]): continue
  66. self.ocr_res.append(o)
  67. self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
  68. self.eps = self.avg_height * 0.7
  69. @property
  70. def is_horizontal(self):
  71. res = self.ocr_res
  72. wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
  73. return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
  74. @property
  75. def avg_height(self):
  76. idx = self.is_horizontal + 0
  77. return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
  78. # 整体置信度
  79. @property
  80. def confidence(self):
  81. return np.mean([r.conf for r in self.ocr_res])
  82. # 处理器函数
  83. def parse(self, eps=40.0):
  84. # 存返回值
  85. res = []
  86. length = len(self.ocr_res)
  87. for i in range(length):
  88. res_i = self.ocr_res[i]
  89. if any(map(lambda x: res_i in x, res)): continue
  90. res_row = set()
  91. for j in range(i, length):
  92. res_j = self.ocr_res[j]
  93. if any(map(lambda x: res_j in x, res)): continue
  94. if res_i.one_line(res_j, self.is_horizontal, self.eps):
  95. res_row.add(res_j)
  96. res.append(res_row)
  97. idx = self.is_horizontal + 0
  98. res = sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
  99. return res
  100. def detection_parse(self, eps=40.0):
  101. result = self.ocr_res
  102. if len(result) == 2:
  103. return result[0].one_line(result[1], True, self.eps)