text2md.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from typing import List
  2. from mdutils.mdutils import MdUtils
  3. from YQ_OCR.to_md.datasets import Dataset
  4. class TableMD(object):
  5. def __init__(self, img_name):
  6. self.img_name = img_name
  7. self.acc = 0
  8. self.f = MdUtils(file_name='../output/' + self.img_name.split('.')[0] + '-表格识别结果')
  9. self.table_structure: List = ['原模型表格正确率', '新模型表格准确率']
  10. self.new_table_text: List = ['位置', '标注结果', '新模型推理', '是否一致']
  11. self.old_table_text: List = ['位置', '标注结果', '原模型推理', '是否一致']
  12. self.write_header(f'表格识别结果测试报告')
  13. def write_header(self, title, level=1):
  14. self.f.new_header(level=level, title=title)
  15. def write_table_accuracy(self, ds: Dataset, key, columns=4, text_align='center'):
  16. def get_format_table_accuracy(str1, str2):
  17. n1 = len(str1)
  18. n2 = len(str2)
  19. if n1 == 0 or n2 == 0:
  20. return ''
  21. dp = [[0] * (n2 + 1) for _ in range(n1 + 1)]
  22. Max = 0
  23. pos = 0
  24. for i in range(1, n1 + 1):
  25. for j in range(1, n2 + 1):
  26. if str1[i - 1] == str2[j - 1]:
  27. dp[i][j] = dp[i - 1][j - 1] + 1
  28. else:
  29. dp[i][j] = 0
  30. if dp[i][j] > Max:
  31. Max = dp[i][j]
  32. pos = i - 1
  33. return str1[pos - Max + 1:pos + 1]
  34. pre_list = ds.get_pre_list()
  35. gt_list = ds.get_gt_list()
  36. # print(pre_list)
  37. # print(gt_list)
  38. correct = 0
  39. count = 0
  40. n = len(pre_list)
  41. m = len(gt_list)
  42. if n < m:
  43. pre_list.extend(['' for _ in range(m - n)])
  44. else:
  45. gt_list.extend(['' for _ in range(n - m)])
  46. for x in range(len(gt_list)):
  47. gt_parse_list = gt_list[x].split('*')
  48. gt_parse_list.pop()
  49. pre_parse_list = pre_list[x].split('*')
  50. pre_parse_list.pop()
  51. # print(gt_parse_list)
  52. # print(pre_parse_list)
  53. n1 = len(pre_parse_list)
  54. m1 = len(gt_parse_list)
  55. # print(n1, m1)
  56. if n1 < m1:
  57. pre_parse_list.extend(['' for _ in range(m1 - n1)])
  58. else:
  59. gt_parse_list.extend(['' for _ in range(n1 - m1)])
  60. for j in range(len(gt_parse_list)):
  61. count += 1
  62. # infer = get_format_table_accuracy(gt_list[x], pre_list[x])
  63. if gt_parse_list[j] == pre_parse_list[j] or \
  64. gt_parse_list[j].replace(' ', '') == pre_parse_list[j].replace(' ', ''):
  65. correct += 1
  66. if key == 'new':
  67. self.new_table_text.extend(
  68. [f'{x + 1}行',
  69. gt_parse_list[j],
  70. pre_parse_list[j],
  71. '✅' if gt_parse_list[j] == pre_parse_list[j] else '❌'])
  72. elif key == 'old':
  73. self.old_table_text.extend(
  74. [f'{x + 1}行',
  75. gt_parse_list[j],
  76. pre_parse_list[j],
  77. '✅' if gt_parse_list[j] == pre_parse_list[j] else '❌'])
  78. acc = correct / count * 100
  79. self.acc = acc
  80. if key == 'new':
  81. rows = len(self.new_table_text) // columns
  82. self.write_header(level=3,
  83. title=f'{self.img_name},'
  84. f'共检测{count}处,'
  85. f'正确{correct},'
  86. f'错误{count - correct},'
  87. f'表格正确率:{acc:.2f}%')
  88. self.f.new_table(columns=columns, rows=rows, text=self.new_table_text, text_align=text_align)
  89. elif key == 'old':
  90. rows = len(self.old_table_text) // columns
  91. self.f.new_header(level=3,
  92. title=f'{self.img_name},'
  93. f'共检测{count}处,'
  94. f'正确{correct},'
  95. f'错误{count - correct},'
  96. f'表格正确率:{acc:.2f}%')
  97. self.f.new_table(columns=columns, rows=rows, text=self.old_table_text, text_align=text_align)
  98. def get_table_accuracy(self):
  99. if self.acc < 0.6:
  100. with open('../output/worst.txt', 'a') as f:
  101. f.write(self.img_name + '\n')
  102. return self.acc