line_parser.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import numpy as np
  2. from dataclasses import dataclass
  3. # result 对象
  4. @dataclass
  5. class OcrResult(object):
  6. box: np.ndarray
  7. txt: str
  8. conf: float
  9. def __hash__(self):
  10. return hash(repr(self))
  11. def __repr__(self):
  12. return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
  13. @property
  14. def lt(self):
  15. l, t = np.min(self.box, 0)
  16. return [l, t]
  17. @property
  18. def rb(self):
  19. r, b = np.max(self.box, 0)
  20. return [r, b]
  21. @property
  22. def wh(self):
  23. l, t = self.lt
  24. r, b = self.rb
  25. return [r - l, b - t]
  26. @property
  27. def area(self):
  28. w, h = self.wh
  29. return w * h
  30. @property
  31. def is_slope(self):
  32. p0 = self.box[0]
  33. p1 = self.box[1]
  34. if p0[0] == p1[0]:
  35. return False
  36. slope = abs(1. * (p0[1] - p1[1]) / (p0[0] - p1[0]))
  37. return 0.4 < slope < 2.5
  38. @property
  39. def center(self):
  40. l, t = self.lt
  41. r, b = self.rb
  42. return [(r + l) / 2, (b + t) / 2]
  43. def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
  44. y_idx = 0 + is_horizontal
  45. x_idx = 1 - y_idx
  46. if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
  47. if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
  48. eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
  49. dist = abs(self.center[y_idx] - b.center[y_idx])
  50. return dist < eps
  51. # 行处理器
  52. class LineParser(object):
  53. def __init__(self, ocr_raw_result, filters=None):
  54. # if filters is None:
  55. # filters = [lambda x: x.is_slope]
  56. self.ocr_res = []
  57. for re in ocr_raw_result:
  58. o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  59. if any([f(o) for f in filters]): continue
  60. self.ocr_res.append(o)
  61. # for f in filters:
  62. # self.ocr_res = list(filter(f, self.ocr_res))
  63. self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
  64. self.eps = self.avg_height * 0.7
  65. @property
  66. def is_horizontal(self):
  67. res = self.ocr_res
  68. wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
  69. return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
  70. @property
  71. def avg_height(self):
  72. idx = self.is_horizontal + 0
  73. return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
  74. # 整体置信度
  75. @property
  76. def confidence(self):
  77. return np.mean([r.conf for r in self.ocr_res])
  78. # 处理器函数
  79. # @sxtimeit
  80. def parse(self, eps=40.0):
  81. # 存返回值
  82. res = []
  83. # 需要 处理的 OcrResult 对象 的长度
  84. length = len(self.ocr_res)
  85. print('length: ', length)
  86. # 如果字段数 小于等于1 就抛出异常
  87. if length <= 1:
  88. raise Exception('无法识别')
  89. in_lines = set()
  90. # 遍历数组 并处理他
  91. for i in range(length):
  92. # print('in lines', in_lines)
  93. # 拿出 OcrResult对象的 第i值 -暂存-
  94. res_i = self.ocr_res[i]
  95. # 这次的 res_i 之前已经在结果集中,就继续下一个
  96. # if any(map(lambda x: res_i in x, res)): continue
  97. if i in in_lines: continue
  98. # set() -> {}
  99. # 初始化一个集合 即-输出-
  100. res_row = set()
  101. for j in range(i, length):
  102. res_j = self.ocr_res[j]
  103. # 这次的 res_i 之前已经在结果集中,就继续下一个
  104. # if any(map(lambda x: res_j in x, res)): continue
  105. if j in in_lines: continue
  106. if res_i.one_line(res_j, self.is_horizontal, self.eps):
  107. # LineParser 对象 不可以直接加入字典
  108. res_row.add(res_j)
  109. in_lines.add(j)
  110. res.append(res_row)
  111. idx = self.is_horizontal + 0
  112. return sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
  113. import numpy as np
  114. from dataclasses import dataclass
  115. # result 对象
  116. @dataclass
  117. class OcrResult(object):
  118. box: np.ndarray
  119. txt: str
  120. conf: float
  121. def __hash__(self):
  122. return hash(repr(self))
  123. def __repr__(self):
  124. return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
  125. @property
  126. def lt(self):
  127. l, t = np.min(self.box, 0)
  128. return [l, t]
  129. @property
  130. def rb(self):
  131. r, b = np.max(self.box, 0)
  132. return [r, b]
  133. @property
  134. def wh(self):
  135. l, t = self.lt
  136. r, b = self.rb
  137. return [r - l, b - t]
  138. @property
  139. def area(self):
  140. w, h = self.wh
  141. return w * h
  142. @property
  143. def is_slope(self):
  144. p0 = self.box[0]
  145. p1 = self.box[1]
  146. if p0[0] == p1[0]:
  147. return False
  148. slope = abs(1. * (p0[1] - p1[1]) / (p0[0] - p1[0]))
  149. return 0.4 < slope < 2.5
  150. @property
  151. def center(self):
  152. l, t = self.lt
  153. r, b = self.rb
  154. return [(r + l) / 2, (b + t) / 2]
  155. def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
  156. y_idx = 0 + is_horizontal
  157. x_idx = 1 - y_idx
  158. if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
  159. if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
  160. eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
  161. dist = abs(self.center[y_idx] - b.center[y_idx])
  162. return dist < eps
  163. # 行处理器
  164. class LineParser(object):
  165. def __init__(self, ocr_raw_result, filters=None):
  166. if filters is None:
  167. filters = [lambda x: x.is_slope]
  168. self.ocr_res = []
  169. for re in ocr_raw_result:
  170. o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  171. if any([f(o) for f in filters]): continue
  172. self.ocr_res.append(o)
  173. # for f in filters:
  174. # self.ocr_res = list(filter(f, self.ocr_res))
  175. self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
  176. self.eps = self.avg_height * 0.7
  177. # self.ocr_res = []
  178. # for re in ocr_raw_result:
  179. # o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  180. # self.ocr_res.append(o)
  181. # self.eps = self.avg_height * 0.86
  182. @property
  183. def is_horizontal(self):
  184. res = self.ocr_res
  185. wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
  186. return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
  187. @property
  188. def avg_height(self):
  189. idx = self.is_horizontal + 0
  190. return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
  191. # 整体置信度
  192. @property
  193. def confidence(self):
  194. return np.mean([r.conf for r in self.ocr_res])
  195. # 处理器函数
  196. def parse(self, eps=40.0):
  197. # 存返回值
  198. res = []
  199. # 需要 处理的 OcrResult 对象 的长度
  200. length = len(self.ocr_res)
  201. # 如果字段数 小于等于1 就抛出异常
  202. if length <= 1:
  203. raise Exception('无法识别')
  204. # 遍历数组 并处理他
  205. for i in range(length):
  206. # 拿出 OcrResult对象的 第i值 -暂存-
  207. res_i = self.ocr_res[i]
  208. # 这次的 res_i 之前已经在结果集中,就继续下一个
  209. if any(map(lambda x: res_i in x, res)): continue
  210. # set() -> {}
  211. # 初始化一个集合 即-输出-
  212. res_row = set()
  213. for j in range(i, length):
  214. res_j = self.ocr_res[j]
  215. # 这次的 res_i 之前已经在结果集中,就继续下一个
  216. if any(map(lambda x: res_j in x, res)): continue
  217. if res_i.one_line(res_j, self.is_horizontal, self.eps):
  218. # LineParser 对象 不可以直接加入字典
  219. res_row.add(res_j)
  220. res.append(res_row)
  221. idx = self.is_horizontal + 0
  222. return sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])