import copy
import os
from typing import List

from mdutils.mdutils import MdUtils

from YQ_OCR.utils.datasets import Dataset
from YQ_OCR.utils.utils import Levenshtein_Distance


def _remove_spaces(text):
    """Return *text* with every ASCII space removed (for space-insensitive comparison)."""
    return text.replace(' ', '')


def _pad_to_same_length(first, second):
    """Right-pad the shorter of two lists with '' (in place) so both have equal length."""
    width = max(len(first), len(second))
    first.extend([''] * (width - len(first)))
    second.extend([''] * (width - len(second)))


class TableMD(object):
    """Collect OCR evaluation results for one image into a Markdown report.

    The report is built through ``self.f`` (an ``MdUtils`` instance). Nothing in this
    class writes the report to disk; presumably the caller invokes
    ``self.f.create_md_file()`` — verify against callers.
    """

    def __init__(self, img_name):
        """Start a report for *img_name*.

        :param img_name: image file name; its final extension is stripped to build the
            report name ``./output/<stem>-表格识别结果``.
        """
        self.img_name = img_name
        # Accuracy in percent (0-100) from the most recent write_table_accuracy call.
        self.acc = 0
        # splitext strips only the final extension: "a.b.png" -> "a.b", "./x.png" -> "./x"
        # (split('.')[0] would give "a" and "" respectively).
        stem = os.path.splitext(self.img_name)[0]
        self.f = MdUtils(file_name='./output/' + stem + '-表格识别结果')
        # Header rows of the Markdown tables; result rows are appended after them.
        self.table_structure: List[str] = ['原模型表格正确率', '新模型表格准确率']
        self.table_result: List[str] = ['key值', '正确答案', 'ocr返回结果', '是否正确']
        self.new_table_text: List[str] = ['位置', '标注结果', '新模型推理', '是否一致']
        self.old_table_text: List[str] = ['位置', '标注结果', '原模型推理', '是否一致']
        self.write_header('测试结果报告')

    def write_header(self, title, level=1):
        """Append a Markdown header of the given *level* to the report."""
        self.f.new_header(level=level, title=title)

    def write_table_accuracy(self, ds: Dataset, key, columns=4, text_align='center'):
        """Compare predicted table cells with ground truth and write the result table.

        Each entry of ``ds.get_gt_list()`` / ``ds.get_pre_list()`` is one table row whose
        cells are separated by ``'*'``. The text after the last ``'*'`` is discarded
        (presumably every row ends with a trailing ``'*'`` — confirm against Dataset).
        Missing rows or cells on either side are padded with ``''`` and still counted.
        Two cells match when they are equal after removing spaces.

        :param ds: dataset supplying the predicted and ground-truth row lists.
        :param key: ``'new'`` or ``'old'``; selects which model table receives the rows
            and is written to the report. Any other value only computes the accuracy.
        :param columns: number of columns of the Markdown table.
        :param text_align: alignment passed to ``MdUtils.new_table``.
        :return: None; the accuracy (0-100) is stored in ``self.acc``.
        """
        # Copy so the padding below does not mutate lists the dataset may still own.
        pre_list = list(ds.get_pre_list())
        gt_list = list(ds.get_gt_list())
        _pad_to_same_length(pre_list, gt_list)

        table_text = {'new': self.new_table_text, 'old': self.old_table_text}.get(key)

        correct = 0
        count = 0
        for row_no, (gt_row, pre_row) in enumerate(zip(gt_list, pre_list), start=1):
            gt_cells = gt_row.split('*')[:-1]
            pre_cells = pre_row.split('*')[:-1]
            _pad_to_same_length(pre_cells, gt_cells)
            for gt_cell, pre_cell in zip(gt_cells, pre_cells):
                count += 1
                is_match = _remove_spaces(gt_cell) == _remove_spaces(pre_cell)
                if is_match:
                    correct += 1
                if table_text is not None:
                    table_text.extend(
                        [f'{row_no}行', gt_cell, pre_cell, '✅' if is_match else '❌'])

        # An empty dataset yields 0% instead of raising ZeroDivisionError.
        acc = correct / count * 100 if count else 0.0
        self.acc = acc
        if table_text is None:
            return
        self.write_header(level=2, title=f'表格识别正确率:{acc:.2f}%')
        self.write_header(level=3, title=f'共检测{count}处,正确{correct},错误{count - correct}')
        self.f.new_table(columns=columns, rows=len(table_text) // columns,
                         text=table_text, text_align=text_align)

    def get_table_accuracy(self, threshold=60.0, worst_path='../output/worst.txt'):
        """Return ``self.acc``; log the image name when accuracy is below *threshold*.

        :param threshold: accuracy in percent (same 0-100 scale as ``self.acc``) below
            which the image name is appended to *worst_path*.
        :param worst_path: file collecting poorly recognised images.
            NOTE(review): this default is relative to ``..`` while reports go to
            ``./output/`` — confirm the intended working directory.
        :return: the accuracy in percent.
        """
        if self.acc < threshold:
            with open(worst_path, 'a', encoding='utf-8') as f:
                f.write(self.img_name + '\n')
        return self.acc

    def evaluate_one(self, xlsx_dict, res_dict):
        """Compare an OCR result with the ground truth and write the comparison table.

        Compared fields:

        * every string value of *res_dict*, against the ground-truth value under the
          same key (a key missing from the ground truth is reported as wrong);
        * every entry of ``xlsx_dict['noKeyList']``, against its closest entry (by
          ``Levenshtein_Distance``) in ``res_dict['noKeyList']``.

        Spaces are ignored on both sides of every comparison.

        :param xlsx_dict: ground-truth mapping (string or list-of-string values).
        :param res_dict: OCR output mapping with the same layout.
        :return: ``(rate, statistics)`` — accuracy in percent for the fields compared in
            this call, and a human-readable summary string.
        """
        def normalize(value):
            if isinstance(value, str):
                return _remove_spaces(value)
            if isinstance(value, list):
                return [_remove_spaces(v) for v in value]
            return value

        gt_no_space = {field: normalize(value) for field, value in xlsx_dict.items()}
        true_num = 0
        # Counted locally so a second call does not divide by rows from earlier calls.
        all_num = 0

        # Fields that have a key: compare value against value.
        for field, res_value in res_dict.items():
            if not isinstance(res_value, str):
                continue
            all_num += 1
            if field in gt_no_space:
                gt_value = xlsx_dict[field]
                is_match = Levenshtein_Distance(_remove_spaces(res_value), gt_no_space[field]) == 0
            else:
                gt_value = ''
                is_match = False
            if is_match:
                true_num += 1
            self.table_result.extend([field, gt_value, res_value, '✅' if is_match else '❌'])

        # Fields without a key: pair each ground-truth entry with its closest OCR entry.
        res_no_key = res_dict.get('noKeyList', [])
        for gt_text, gt_text_no_space in zip(xlsx_dict.get('noKeyList', []),
                                             gt_no_space.get('noKeyList', [])):
            all_num += 1
            if res_no_key:
                # min keeps the first minimum, i.e. the earliest closest candidate.
                distance, best = min(
                    ((Levenshtein_Distance(gt_text_no_space, _remove_spaces(candidate)), candidate)
                     for candidate in res_no_key),
                    key=lambda pair: pair[0])
            else:
                distance, best = None, ''
            is_match = distance == 0
            if is_match:
                true_num += 1
            self.table_result.extend(['无key值', gt_text, best, '✅' if is_match else '❌'])

        rate = true_num / all_num * 100 if all_num else 0.0
        statistics = f'共{all_num}个字段,正确{true_num}个,错误{all_num - true_num}个'
        self.write_header(level=2, title=f'文字识别正确率:{rate:.2f}%')
        self.write_header(level=3, title=statistics)
        self.f.new_table(columns=4, rows=len(self.table_result) // 4,
                         text=self.table_result, text_align='center')
        return rate, statistics