square_parser.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from typing import List
  2. import difflib
  3. import numpy as np
  4. from dataclasses import dataclass
  5. from blfe_core.line_parser import OcrResult
  6. # 确定key的第一行
  7. def get_key_fist_line(res_line_list, key):
  8. def string_similar(s1, s2):
  9. return difflib.SequenceMatcher(None, s1, s2).quick_ratio()
  10. # 需改动
  11. if key == '经营范围':
  12. print(res_line_list[-1][0])
  13. key_str = res_line_list[-1][0].split('市')[0].split('住所')[0].split('经营范围')[-1]
  14. print('key_str', key_str)
  15. else:
  16. key_str = res_line_list[-1][0].split(key)[-1]
  17. # title
  18. key_title = False
  19. key_title_list = []
  20. # print(res_line_list[:-1])
  21. for r in res_line_list[:-1]:
  22. # print(r.txt)
  23. if string_similar(r.txt, key) > 0.7:
  24. if len(r.txt) > len(key_str) + 2:
  25. box = r.box
  26. raw_w = box[1][0] - box[0][0]
  27. ratio = len(key) / len(r.txt)
  28. title_w = raw_w * ratio
  29. box[1][0] = box[0][0] + title_w
  30. box[2][0] = box[0][0] + title_w
  31. key_title = OcrResult(np.array(box), key, r.txt)
  32. break
  33. else:
  34. key_title = r
  35. break
  36. elif string_similar(r.txt, key) > 0.5 and len(r.txt) == 1:
  37. key_title_list.append(r)
  38. if key_title_list:
  39. key_title = key_title_list[-1]
  40. # 特殊处理
  41. if type(res_line_list[0]) == OcrResult and res_line_list[0].txt == '经营范围' and key == '经营范围':
  42. return res_line_list[1], key_title or res_line_list[1]
  43. max_num = 0
  44. max_or = None
  45. for rll_k, rll_v in enumerate(res_line_list[:-1]):
  46. m_num = string_similar(key_str, rll_v.txt)
  47. m_or = rll_v
  48. if m_num > max_num:
  49. max_num = m_num
  50. max_or = m_or
  51. max_or.txt = max_or.txt.split(key)[-1]
  52. return max_or, key_title if key_title else max_or
  53. def get_key_other_or(res_raw_list, key_heard: OcrResult, key_title):
  54. def h_range():
  55. h_list = []
  56. for key in keys_list:
  57. h_list.append(key.wh[1])
  58. mean_h = np.mean(h_list)
  59. h_range = (mean_h * 0, mean_h * 1.1)
  60. return h_range
  61. def is_title(r: OcrResult):
  62. left_len = h_range()[0] * 2
  63. r_point = [r.lt[0] - left_len, (r.lt[1] + r.wh[1]) / 2]
  64. title_list = []
  65. for res in res_raw_list:
  66. if res.lt[0] < r_point[0] < res.rb[0] and res.lt[1] < r_point[1] < res.rb[1]:
  67. title_list.append(res)
  68. if not title_list:
  69. return True
  70. for t in title_list:
  71. if t.txt == key_title.txt:
  72. return True
  73. return False
  74. def merge_box(boxes: List[OcrResult]):
  75. txt = boxes[0].txt
  76. box = boxes[0].box
  77. conf = boxes[0].conf
  78. for l_b in boxes[1:]:
  79. txt = txt + l_b.txt
  80. l, t = np.min(np.min([box, l_b.box], 0), 0)
  81. r, b = np.max(np.max([box, l_b.box], 0), 0)
  82. box = np.array([[l, t], [r, t], [r, b], [l, b]])
  83. conf = np.mean([conf, l_b.conf])
  84. return OcrResult(box, txt, conf)
  85. keys_list = [key_heard]
  86. x_line_list = [key_heard]
  87. anchor_key: OcrResult = key_heard
  88. for cell_y_k, cell_y_v in enumerate(res_raw_list):
  89. cell_x_line = []
  90. for cell_x_k, cell_x_v in enumerate(res_raw_list[cell_y_k:]):
  91. # cell 0<y<h的均值 竖直方向上
  92. # cell a_l<x<a_r 水平方向上 or
  93. if (h_range()[0] < (cell_x_v.center[1] - anchor_key.center[1]) < h_range()[1] and anchor_key.lt[0] <
  94. cell_x_v.center[0] < anchor_key.rb[0]) or \
  95. (abs(cell_x_v.center[1] - anchor_key.center[1]) < h_range()[1] and 0 < cell_x_v.lt[0] -
  96. anchor_key.rb[0] < h_range()[1] * 3):
  97. if is_title(cell_x_v):
  98. cell_x_line.append(cell_x_v)
  99. # 合并单元格
  100. if bool(cell_x_line):
  101. x_line_list.append(merge_box(cell_x_line))
  102. anchor_key = merge_box(cell_x_line)
  103. result = merge_box(x_line_list)
  104. return result
  105. def parser_xy(res_line, res_raw, key):
  106. # 在 res_line 中找到 key 对应的坐标
  107. key_row = []
  108. for row in res_line:
  109. print(row[-1])
  110. if key in row[-1][0]:
  111. key_row = row
  112. break
  113. if not bool(key_row): return
  114. key_heard, key_title = get_key_fist_line(key_row, key)
  115. return get_key_other_or(res_raw, key_heard, key_title)