line_parser.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import numpy as np
  2. from dataclasses import dataclass
  3. # result 对象
  4. @dataclass
  5. class OcrResult(object):
  6. box: np.ndarray
  7. txt: str
  8. conf: float
  9. def __hash__(self):
  10. return hash(repr(self))
  11. def __repr__(self):
  12. return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
  13. @property
  14. def lt(self):
  15. l, t = np.min(self.box, 0)
  16. return [l, t]
  17. @property
  18. def rb(self):
  19. r, b = np.max(self.box, 0)
  20. return [r, b]
  21. @property
  22. def wh(self):
  23. l, t = self.lt
  24. r, b = self.rb
  25. return [r - l, b - t]
  26. @property
  27. def center(self):
  28. l, t = self.lt
  29. r, b = self.rb
  30. return [(r + l) / 2, (b + t) / 2]
  31. def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
  32. if is_horizontal:
  33. eps = 0.5 * (self.wh[1] + b.wh[1])
  34. return abs(self.center[1] - b.center[1]) < eps
  35. else:
  36. eps = 0.5 * (self.wh[0] + b.wh[0])
  37. return abs(self.center[0] - b.center[0]) < eps
  38. # 行处理器
  39. class LineParser(object):
  40. def __init__(self, ocr_raw_result):
  41. self.ocr_res = []
  42. for re in ocr_raw_result:
  43. o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
  44. self.ocr_res.append(o)
  45. self.eps = self.avg_height * 0.7
  46. @property
  47. def is_horizontal(self):
  48. res = self.ocr_res
  49. wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
  50. return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
  51. @property
  52. def avg_height(self):
  53. idx = self.is_horizontal + 0
  54. return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
  55. # 整体置信度
  56. @property
  57. def confidence(self):
  58. return np.mean([r.conf for r in self.ocr_res])
  59. # 处理器函数
  60. def parse(self, eps=40.0):
  61. # 存返回值
  62. res = []
  63. # 需要 处理的 OcrResult 对象 的长度
  64. length = len(self.ocr_res)
  65. # 如果字段数 小于等于1 就抛出异常
  66. if length <= 1:
  67. raise Exception('无法识别')
  68. # 遍历数组 并处理他
  69. for i in range(length):
  70. # 拿出 OcrResult对象的 第i值 -暂存-
  71. res_i = self.ocr_res[i]
  72. # any:-> True
  73. # -input: 可迭代对象 | -output: bool
  74. # -如果iterable的任何元素为true,则返回true。如果iterable为空,则返回false。 -与🚪-
  75. # map: -> [False, False, False, False, True, True, False, False]
  76. # -input: (函数, 可迭代对象) | -output: 可迭代对象
  77. # -把 res 喂给lambda --lambda返回True的值--> 输出 新的可迭代对象
  78. # 这次的 res_i 之前已经在结果集中,就继续下一个
  79. if any(map(lambda x: res_i in x, res)): continue
  80. # set() -> {}
  81. # 初始化一个集合 即-输出-
  82. res_row = set()
  83. for j in range(i, length):
  84. res_j = self.ocr_res[j]
  85. if res_i.one_line(res_j, self.is_horizontal, self.eps):
  86. # LineParser 对象 不可以直接加入字典
  87. res_row.add(res_j)
  88. res.append(res_row)
  89. # 对于身份证
  90. # 进入line_parser 时已经水平
  91. # 故 外层是按x排序 里层按y值排序
  92. return sorted([sorted(list(r), key=lambda x: x.lt[0]) for r in res], key=lambda x: x[0].lt[1])