line_parser.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import numpy as np
  2. from dataclasses import dataclass
  3. # result 对象
  4. @dataclass
  5. class OcrResult(object):
  6. box: np.ndarray
  7. txt: str
  8. conf: float
  9. def __hash__(self):
  10. return hash(repr(self))
  11. def __repr__(self):
  12. return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
  13. @property
  14. def lt(self):
  15. l, t = np.min(self.box, 0)
  16. return [l, t]
  17. @property
  18. def rb(self):
  19. r, b = np.max(self.box, 0)
  20. return [r, b]
  21. @property
  22. def wh(self):
  23. l, t = self.lt
  24. r, b = self.rb
  25. return [r - l, b - t]
  26. @property
  27. def center(self):
  28. l, t = self.lt
  29. r, b = self.rb
  30. return [(r + l)/2, (b + t)/2]
  31. def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
  32. if is_horizontal:
  33. return abs(self.lt[1] - b.lt[1]) < eps
  34. else:
  35. return abs(self.rb[0] - b.rb[0]) < eps
  36. # 行处理器
  37. class LineParser(object):
  38. def __init__(self, ocr_raw_result):
  39. self.ocr_res = []
  40. for re in ocr_raw_result:
  41. o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  42. self.ocr_res.append(o)
  43. self.eps = self.avg_height * 0.66
  44. @property
  45. def is_horizontal(self):
  46. res = self.ocr_res
  47. wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
  48. return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
  49. @property
  50. def avg_height(self):
  51. idx = self.is_horizontal + 0
  52. return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
  53. # 整体置信度
  54. @property
  55. def confidence(self):
  56. return np.mean([r.conf for r in self.ocr_res])
  57. # 处理器函数
  58. def parse(self, eps=40.0):
  59. # 存返回值
  60. res = []
  61. # 需要 处理的 OcrResult 对象 的长度
  62. length = len(self.ocr_res)
  63. # 如果字段数 小于等于1 就抛出异常
  64. if length <= 1:
  65. raise Exception('无法识别')
  66. # 遍历数组 并处理他
  67. for i in range(length):
  68. # 拿出 OcrResult对象的 第i值 -暂存-
  69. res_i = self.ocr_res[i]
  70. # any:-> True
  71. # -input: 可迭代对象 | -output: bool
  72. # -如果iterable的任何元素为true,则返回true。如果iterable为空,则返回false。 -与🚪-
  73. # map: -> [False, False, False, False, True, True, False, False]
  74. # -input: (函数, 可迭代对象) | -output: 可迭代对象
  75. # -把 res 喂给lambda --lambda返回True的值--> 输出 新的可迭代对象
  76. # 这次的 res_i 之前已经在结果集中,就继续下一个
  77. if any(map(lambda x: res_i in x, res)): continue
  78. # set() -> {}
  79. # 初始化一个集合 即-输出-
  80. res_row = set()
  81. for j in range(i, length):
  82. res_j = self.ocr_res[j]
  83. if res_i.one_line(res_j, self.is_horizontal, self.eps):
  84. # LineParser 对象 不可以直接加入字典
  85. res_row.add(res_j)
  86. res.append(res_row)
  87. idx = self.is_horizontal + 0
  88. return sorted([list(r) for r in res], key=lambda x: x[0].lt[idx])