line_parser.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import numpy as np
  2. from dataclasses import dataclass
  3. # result 对象
  4. @dataclass
  5. class OcrResult(object):
  6. box: np.ndarray
  7. txt: str
  8. conf: float
  9. def __hash__(self):
  10. return hash(repr(self))
  11. def __repr__(self):
  12. return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
  13. @property
  14. def lt(self):
  15. l, t = np.min(self.box, 0)
  16. return [l, t]
  17. @property
  18. def rb(self):
  19. r, b = np.max(self.box, 0)
  20. return [r, b]
  21. @property
  22. def wh(self):
  23. l, t = self.lt
  24. r, b = self.rb
  25. return [r - l, b - t]
  26. def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
  27. if is_horizontal:
  28. return abs(self.lt[1] - b.lt[1]) < eps
  29. else:
  30. return abs(self.rb[0] - b.rb[0]) < eps
  31. # 行处理器
  32. class LineParser(object):
  33. def __init__(self, ocr_raw_result):
  34. self.ocr_res = []
  35. for re in ocr_raw_result:
  36. o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  37. self.ocr_res.append(o)
  38. self.eps = self.avg_height * 0.66
  39. @property
  40. def is_horizontal(self):
  41. res = self.ocr_res
  42. wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
  43. return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
  44. @property
  45. def avg_height(self):
  46. idx = self.is_horizontal + 0
  47. return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
  48. @property
  49. def confidence(self):
  50. return np.mean([r.conf for r in self.ocr_res])
  51. # 处理器函数
  52. def parse(self, eps=40.0):
  53. # 存返回值
  54. res = []
  55. # 需要 处理的 OcrResult 对象 的长度
  56. length = len(self.ocr_res)
  57. # 如果字段数 小于等于1 就抛出异常
  58. if length <= 1:
  59. raise Exception('无法识别')
  60. # 遍历数组 并处理他
  61. for i in range(length):
  62. # 拿出 OcrResult对象的 第i值 -暂存-
  63. res_i = self.ocr_res[i]
  64. # any:-> True
  65. # -input: 可迭代对象 | -output: bool
  66. # -如果iterable的任何元素为true,则返回true。如果iterable为空,则返回false。 -与🚪-
  67. # map: -> [False, False, False, False, True, True, False, False]
  68. # -input: (函数, 可迭代对象) | -output: 可迭代对象
  69. # -把 res 喂给lambda --lambda返回True的值--> 输出 新的可迭代对象
  70. # 这次的 res_i 之前已经在结果集中,就继续下一个
  71. if any(map(lambda x: res_i in x, res)): continue
  72. # set() -> {}
  73. # 初始化一个集合 即-输出-
  74. res_row = set()
  75. for j in range(i, length):
  76. res_j = self.ocr_res[j]
  77. if res_i.one_line(res_j, self.is_horizontal, self.eps):
  78. # LineParser 对象 不可以直接加入字典
  79. res_row.add(res_j)
  80. res.append(res_row)
  81. idx = self.is_horizontal + 0
  82. return sorted([list(r) for r in res], key=lambda x: x[0].lt[idx])