text2md.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import copy
  2. from typing import List
  3. from mdutils.mdutils import MdUtils
  4. from YQ_OCR.utils.datasets import Dataset
  5. from YQ_OCR.utils.utils import Levenshtein_Distance
  6. class TableMD(object):
  7. def __init__(self, img_name):
  8. self.img_name = img_name
  9. self.acc = 0
  10. self.f = MdUtils(file_name='./output/' + self.img_name.split('.')[0] + '-表格识别结果')
  11. self.table_structure: List = ['原模型表格正确率', '新模型表格准确率']
  12. self.table_result: List = ['key值', '正确答案', 'ocr返回结果', '是否正确']
  13. self.new_table_text: List = ['位置', '标注结果', '新模型推理', '是否一致']
  14. self.old_table_text: List = ['位置', '标注结果', '原模型推理', '是否一致']
  15. self.write_header(f'测试结果报告')
  16. def write_header(self, title, level=1):
  17. self.f.new_header(level=level, title=title)
  18. def write_table_accuracy(self, ds: Dataset, key, columns=4, text_align='center'):
  19. def get_format_table_accuracy(str1, str2):
  20. n1 = len(str1)
  21. n2 = len(str2)
  22. if n1 == 0 or n2 == 0:
  23. return ''
  24. dp = [[0] * (n2 + 1) for _ in range(n1 + 1)]
  25. Max = 0
  26. pos = 0
  27. for i in range(1, n1 + 1):
  28. for j in range(1, n2 + 1):
  29. if str1[i - 1] == str2[j - 1]:
  30. dp[i][j] = dp[i - 1][j - 1] + 1
  31. else:
  32. dp[i][j] = 0
  33. if dp[i][j] > Max:
  34. Max = dp[i][j]
  35. pos = i - 1
  36. return str1[pos - Max + 1:pos + 1]
  37. pre_list = ds.get_pre_list()
  38. gt_list = ds.get_gt_list()
  39. # print(pre_list)
  40. # print(gt_list)
  41. correct = 0
  42. count = 0
  43. n = len(pre_list)
  44. m = len(gt_list)
  45. if n < m:
  46. pre_list.extend(['' for _ in range(m - n)])
  47. else:
  48. gt_list.extend(['' for _ in range(n - m)])
  49. for x in range(len(gt_list)):
  50. gt_parse_list = gt_list[x].split('*')
  51. gt_parse_list.pop()
  52. pre_parse_list = pre_list[x].split('*')
  53. pre_parse_list.pop()
  54. # print(gt_parse_list)
  55. # print(pre_parse_list)
  56. n1 = len(pre_parse_list)
  57. m1 = len(gt_parse_list)
  58. # print(n1, m1)
  59. if n1 < m1:
  60. pre_parse_list.extend(['' for _ in range(m1 - n1)])
  61. else:
  62. gt_parse_list.extend(['' for _ in range(n1 - m1)])
  63. for j in range(len(gt_parse_list)):
  64. count += 1
  65. # infer = get_format_table_accuracy(gt_list[x], pre_list[x])
  66. if gt_parse_list[j] == pre_parse_list[j] or \
  67. gt_parse_list[j].replace(' ', '') == pre_parse_list[j].replace(' ', ''):
  68. correct += 1
  69. if key == 'new':
  70. self.new_table_text.extend(
  71. [f'{x + 1}行',
  72. gt_parse_list[j],
  73. pre_parse_list[j],
  74. '✅' if gt_parse_list[j] == pre_parse_list[j] or gt_parse_list[j].replace(' ', '') ==
  75. pre_parse_list[j].replace(' ', '') else '❌'])
  76. elif key == 'old':
  77. self.old_table_text.extend(
  78. [f'{x + 1}行',
  79. gt_parse_list[j],
  80. pre_parse_list[j],
  81. '✅' if gt_parse_list[j] == pre_parse_list[j] or gt_parse_list[j].replace(' ', '') ==
  82. pre_parse_list[j].replace(' ', '') else '❌'])
  83. acc = correct / count * 100
  84. self.acc = acc
  85. if key == 'new':
  86. rows = len(self.new_table_text) // columns
  87. self.write_header(level=2, title=f'表格识别正确率:{acc:.2f}%')
  88. self.write_header(level=3, title=f'共检测{count}处,'
  89. f'正确{correct},'
  90. f'错误{count - correct}')
  91. self.f.new_table(columns=columns, rows=rows, text=self.new_table_text, text_align=text_align)
  92. elif key == 'old':
  93. rows = len(self.old_table_text) // columns
  94. self.write_header(level=2, title=f'表格识别正确率:{acc:.2f}%')
  95. self.write_header(level=3, title=f'共检测{count}处,'
  96. f'正确{correct},'
  97. f'错误{count - correct}')
  98. self.f.new_table(columns=columns, rows=rows, text=self.old_table_text, text_align=text_align)
  99. def get_table_accuracy(self):
  100. if self.acc < 0.6:
  101. with open('../output/worst.txt', 'a') as f:
  102. f.write(self.img_name + '\n')
  103. return self.acc
  104. # 比较两个json文件 并在md文件中写入对比结果
  105. def evaluate_one(self, xlsx_dict, res_dict):
  106. true_num = 0
  107. xlsx_dict_no_space: dict = copy.deepcopy(xlsx_dict)
  108. for index, text in xlsx_dict_no_space.items():
  109. if type(xlsx_dict_no_space[index]) is str:
  110. xlsx_dict_no_space[index] = text.replace(' ', '')
  111. elif type(xlsx_dict_no_space[index]) is list:
  112. for k, v in enumerate(xlsx_dict_no_space[index]):
  113. xlsx_dict_no_space[index][k] = v.replace(' ', '')
  114. # 有key值的比较
  115. for key_yes in res_dict:
  116. if type(res_dict[key_yes]) is str:
  117. if Levenshtein_Distance(res_dict[key_yes], xlsx_dict_no_space[key_yes]) == 0:
  118. self.table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '✅'])
  119. true_num += 1
  120. else:
  121. self.table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '❌'])
  122. # 无key值的比较
  123. key_no_dict = {}
  124. for key_no_xlsx_no_space, key_no_xlsx in zip(xlsx_dict_no_space['noKeyList'], xlsx_dict['noKeyList']):
  125. key_no_dict[key_no_xlsx_no_space] = []
  126. for key_no_res in res_dict['noKeyList']:
  127. key_no_dict[key_no_xlsx_no_space].append(
  128. (Levenshtein_Distance(key_no_xlsx_no_space, key_no_res), key_no_res))
  129. sort_NoKey = sorted(key_no_dict[key_no_xlsx_no_space], key=lambda x: x[0])
  130. NoKey_min_distance = sort_NoKey[0][0]
  131. if NoKey_min_distance == 0:
  132. self.table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '✅'])
  133. true_num += 1
  134. else:
  135. self.table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '❌'])
  136. # 算正确率
  137. all_num = len(self.table_result) // 4 - 1
  138. rate = true_num / all_num * 100
  139. # all_rate.append(rate)
  140. statistics = f'共{all_num}个字段,正确{true_num}个,错误{all_num - true_num}个'
  141. self.write_header(level=2, title=f'文字识别正确率:{rate}')
  142. self.write_header(level=3, title=statistics)
  143. self.f.new_table(columns=4, rows=len(self.table_result) // 4, text=self.table_result, text_align='center')
  144. return rate, statistics
  145. # def evaluate_one(xlsx_dict, res_dict):
  146. # true_num = 0
  147. # # 有key值的比较
  148. # for key_yes in res_dict:
  149. # if type(res_dict[key_yes]) is str:
  150. # if Levenshtein_Distance(res_dict[key_yes], xlsx_dict[key_yes]) == 0:
  151. # table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '✅'])
  152. # true_num += 1
  153. # else:
  154. # table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '❌'])
  155. # # 无key值的比较
  156. # key_no_dict = {}
  157. # for key_no_xlsx in xlsx_dict['noKeyList']:
  158. # key_no_dict[key_no_xlsx] = []
  159. # for key_no_res in res_dict['noKeyList']:
  160. # key_no_dict[key_no_xlsx].append((Levenshtein_Distance(key_no_xlsx, key_no_res), key_no_res))
  161. # sort_NoKey = sorted(key_no_dict[key_no_xlsx], key=lambda x: x[0])
  162. # NoKey_min_distance = sort_NoKey[0][0]
  163. # if NoKey_min_distance == 0:
  164. # table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '✅'])
  165. # true_num += 1
  166. # else:
  167. # table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '❌'])
  168. # # 算正确率
  169. # rate = true_num / (len(table_result) / 4)
  170. # all_rate.append(rate)
  171. # statistics = f'共{len(table_result) // 4}个字段,正确{true_num}个,错误{len(table_result) // 4 - true_num}个'
  172. # return "{:.2f}%".format(rate * 100), statistics