from itertools import zip_longest
from typing import List

from mdutils.mdutils import MdUtils

from YQ_OCR.to_md.datasets import Dataset


class TableMD(object):
    """Write a Markdown report comparing table-recognition output with ground truth.

    One report file is created per image under ``../output/``.
    ``write_table_accuracy`` appends a per-cell comparison table for either the
    new or the old model. ``get_table_accuracy`` returns the most recently
    computed accuracy.
    """

    def __init__(self, img_name):
        self.img_name = img_name
        # Accuracy in percent (0-100) from the latest write_table_accuracy call.
        self.acc = 0
        self.f = MdUtils(file_name='../output/' + self.img_name.split('.')[0] + '-表格识别结果')
        self.table_structure: List = ['原模型表格正确率', '新模型表格准确率']
        # Flat cell lists handed to MdUtils.new_table. Each starts with a 4-cell
        # header row; write_table_accuracy appends one 4-cell row per compared cell.
        self.new_table_text: List = ['位置', '标注结果', '新模型推理', '是否一致']
        self.old_table_text: List = ['位置', '标注结果', '原模型推理', '是否一致']
        self.write_header('表格识别结果测试报告')

    def write_header(self, title, level=1):
        """Append a Markdown header with the given title and level to the report."""
        self.f.new_header(level=level, title=title)

    @staticmethod
    def _split_cells(row):
        """Split a row string on '*' and drop the piece after the final '*'."""
        cells = row.split('*')
        cells.pop()
        return cells

    @staticmethod
    def _cells_match(gt_cell, pre_cell):
        """Return True when two cells are equal once space characters are removed."""
        return gt_cell.replace(' ', '') == pre_cell.replace(' ', '')

    def write_table_accuracy(self, ds: Dataset, key, columns=4, text_align='center'):
        """Compare predicted and ground-truth table cells and write the result.

        Rows come from ``ds.get_pre_list()`` and ``ds.get_gt_list()``. Each row
        is split into cells on '*'. When one side has fewer rows or cells, it is
        padded with '' so that every cell is counted.

        Args:
            ds: dataset providing the predicted and ground-truth row lists.
            key: 'new' or 'old'; selects which model's table receives the
                rows. For any other value, the accuracy is still computed and
                stored, but nothing is written to the report.
            columns: number of table columns. The header rows hold 4 cells,
                so this should stay 4.
            text_align: alignment passed to ``MdUtils.new_table``.
        """
        # Copy so padding never mutates the dataset's own lists.
        pre_list = list(ds.get_pre_list())
        gt_list = list(ds.get_gt_list())
        table_text = {'new': self.new_table_text, 'old': self.old_table_text}.get(key)

        correct = 0
        count = 0
        for row_idx, (gt_row, pre_row) in enumerate(zip_longest(gt_list, pre_list, fillvalue='')):
            gt_cells = self._split_cells(gt_row)
            pre_cells = self._split_cells(pre_row)
            for gt_cell, pre_cell in zip_longest(gt_cells, pre_cells, fillvalue=''):
                count += 1
                matched = self._cells_match(gt_cell, pre_cell)
                if matched:
                    correct += 1
                if table_text is not None:
                    # Use the same predicate as the counter so the ✅ rows agree
                    # with the "正确" total in the header.
                    table_text.extend([f'{row_idx + 1}行', gt_cell, pre_cell,
                                       '✅' if matched else '❌'])

        # Guard against ZeroDivisionError when there were no cells to compare.
        acc = correct / count * 100 if count else 0.0
        self.acc = acc

        if table_text is None:
            return
        self.write_header(level=3, title=f'{self.img_name},'
                                         f'共检测{count}处,'
                                         f'正确{correct},'
                                         f'错误{count - correct},'
                                         f'表格正确率:{acc:.2f}%')
        self.f.new_table(columns=columns, rows=len(table_text) // columns,
                         text=table_text, text_align=text_align)

    def get_table_accuracy(self, threshold=0.6):
        """Return the latest accuracy and log the image name if it is below ``threshold``.

        Images below the threshold are appended to ``../output/worst.txt``.
        """
        # NOTE(review): self.acc is a percentage (0-100), so the default 0.6
        # means 0.6 %, not 60 %. The default is kept for backward
        # compatibility; confirm the intended threshold.
        if self.acc < threshold:
            with open('../output/worst.txt', 'a', encoding='utf-8') as f:
                f.write(self.img_name + '\n')
        return self.acc