|
@@ -0,0 +1,185 @@
|
|
|
+import copy
|
|
|
+from typing import List
|
|
|
+from mdutils.mdutils import MdUtils
|
|
|
+from YQ_OCR.utils.datasets import Dataset
|
|
|
+from YQ_OCR.utils.utils import Levenshtein_Distance
|
|
|
+
|
|
|
+
|
|
|
+class TableMD(object):
|
|
|
+ def __init__(self, img_name):
|
|
|
+ self.img_name = img_name
|
|
|
+ self.acc = 0
|
|
|
+ self.f = MdUtils(file_name='./output/' + self.img_name.split('.')[0] + '-表格识别结果')
|
|
|
+
|
|
|
+ self.table_structure: List = ['原模型表格正确率', '新模型表格准确率']
|
|
|
+ self.table_result: List = ['key值', '正确答案', 'ocr返回结果', '是否正确']
|
|
|
+ self.new_table_text: List = ['位置', '标注结果', '新模型推理', '是否一致']
|
|
|
+ self.old_table_text: List = ['位置', '标注结果', '原模型推理', '是否一致']
|
|
|
+ self.write_header(f'测试结果报告')
|
|
|
+
|
|
|
+ def write_header(self, title, level=1):
|
|
|
+ self.f.new_header(level=level, title=title)
|
|
|
+
|
|
|
+ def write_table_accuracy(self, ds: Dataset, key, columns=4, text_align='center'):
|
|
|
+ def get_format_table_accuracy(str1, str2):
|
|
|
+ n1 = len(str1)
|
|
|
+ n2 = len(str2)
|
|
|
+ if n1 == 0 or n2 == 0:
|
|
|
+ return ''
|
|
|
+ dp = [[0] * (n2 + 1) for _ in range(n1 + 1)]
|
|
|
+ Max = 0
|
|
|
+ pos = 0
|
|
|
+ for i in range(1, n1 + 1):
|
|
|
+ for j in range(1, n2 + 1):
|
|
|
+ if str1[i - 1] == str2[j - 1]:
|
|
|
+ dp[i][j] = dp[i - 1][j - 1] + 1
|
|
|
+ else:
|
|
|
+ dp[i][j] = 0
|
|
|
+ if dp[i][j] > Max:
|
|
|
+ Max = dp[i][j]
|
|
|
+ pos = i - 1
|
|
|
+ return str1[pos - Max + 1:pos + 1]
|
|
|
+
|
|
|
+ pre_list = ds.get_pre_list()
|
|
|
+ gt_list = ds.get_gt_list()
|
|
|
+ # print(pre_list)
|
|
|
+ # print(gt_list)
|
|
|
+ correct = 0
|
|
|
+ count = 0
|
|
|
+ n = len(pre_list)
|
|
|
+ m = len(gt_list)
|
|
|
+ if n < m:
|
|
|
+ pre_list.extend(['' for _ in range(m - n)])
|
|
|
+ else:
|
|
|
+ gt_list.extend(['' for _ in range(n - m)])
|
|
|
+
|
|
|
+ for x in range(len(gt_list)):
|
|
|
+ gt_parse_list = gt_list[x].split('*')
|
|
|
+ gt_parse_list.pop()
|
|
|
+ pre_parse_list = pre_list[x].split('*')
|
|
|
+ pre_parse_list.pop()
|
|
|
+ # print(gt_parse_list)
|
|
|
+ # print(pre_parse_list)
|
|
|
+ n1 = len(pre_parse_list)
|
|
|
+ m1 = len(gt_parse_list)
|
|
|
+ # print(n1, m1)
|
|
|
+ if n1 < m1:
|
|
|
+ pre_parse_list.extend(['' for _ in range(m1 - n1)])
|
|
|
+ else:
|
|
|
+ gt_parse_list.extend(['' for _ in range(n1 - m1)])
|
|
|
+
|
|
|
+ for j in range(len(gt_parse_list)):
|
|
|
+ count += 1
|
|
|
+ # infer = get_format_table_accuracy(gt_list[x], pre_list[x])
|
|
|
+ if gt_parse_list[j] == pre_parse_list[j] or \
|
|
|
+ gt_parse_list[j].replace(' ', '') == pre_parse_list[j].replace(' ', ''):
|
|
|
+ correct += 1
|
|
|
+ if key == 'new':
|
|
|
+ self.new_table_text.extend(
|
|
|
+ [f'{x + 1}行',
|
|
|
+ gt_parse_list[j],
|
|
|
+ pre_parse_list[j],
|
|
|
+ '✅' if gt_parse_list[j] == pre_parse_list[j] or gt_parse_list[j].replace(' ', '') ==
|
|
|
+ pre_parse_list[j].replace(' ', '') else '❌'])
|
|
|
+ elif key == 'old':
|
|
|
+ self.old_table_text.extend(
|
|
|
+ [f'{x + 1}行',
|
|
|
+ gt_parse_list[j],
|
|
|
+ pre_parse_list[j],
|
|
|
+ '✅' if gt_parse_list[j] == pre_parse_list[j] or gt_parse_list[j].replace(' ', '') ==
|
|
|
+ pre_parse_list[j].replace(' ', '') else '❌'])
|
|
|
+
|
|
|
+ acc = correct / count * 100
|
|
|
+ self.acc = acc
|
|
|
+ if key == 'new':
|
|
|
+ rows = len(self.new_table_text) // columns
|
|
|
+ self.write_header(level=2, title=f'表格识别正确率:{acc:.2f}%')
|
|
|
+ self.write_header(level=3, title=f'共检测{count}处,'
|
|
|
+ f'正确{correct},'
|
|
|
+ f'错误{count - correct}')
|
|
|
+ self.f.new_table(columns=columns, rows=rows, text=self.new_table_text, text_align=text_align)
|
|
|
+ elif key == 'old':
|
|
|
+ rows = len(self.old_table_text) // columns
|
|
|
+ self.write_header(level=2, title=f'表格识别正确率:{acc:.2f}%')
|
|
|
+ self.write_header(level=3, title=f'共检测{count}处,'
|
|
|
+ f'正确{correct},'
|
|
|
+ f'错误{count - correct}')
|
|
|
+ self.f.new_table(columns=columns, rows=rows, text=self.old_table_text, text_align=text_align)
|
|
|
+
|
|
|
+ def get_table_accuracy(self):
|
|
|
+ if self.acc < 0.6:
|
|
|
+ with open('../output/worst.txt', 'a') as f:
|
|
|
+ f.write(self.img_name + '\n')
|
|
|
+ return self.acc
|
|
|
+
|
|
|
+ # 比较两个json文件 并在md文件中写入对比结果
|
|
|
+ def evaluate_one(self, xlsx_dict, res_dict):
|
|
|
+ true_num = 0
|
|
|
+ xlsx_dict_no_space: dict = copy.deepcopy(xlsx_dict)
|
|
|
+ for index, text in xlsx_dict_no_space.items():
|
|
|
+ if type(xlsx_dict_no_space[index]) is str:
|
|
|
+ xlsx_dict_no_space[index] = text.replace(' ', '')
|
|
|
+ elif type(xlsx_dict_no_space[index]) is list:
|
|
|
+ for k, v in enumerate(xlsx_dict_no_space[index]):
|
|
|
+ xlsx_dict_no_space[index][k] = v.replace(' ', '')
|
|
|
+ # 有key值的比较
|
|
|
+ for key_yes in res_dict:
|
|
|
+ if type(res_dict[key_yes]) is str:
|
|
|
+ if Levenshtein_Distance(res_dict[key_yes], xlsx_dict_no_space[key_yes]) == 0:
|
|
|
+ self.table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '✅'])
|
|
|
+ true_num += 1
|
|
|
+ else:
|
|
|
+ self.table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '❌'])
|
|
|
+ # 无key值的比较
|
|
|
+ key_no_dict = {}
|
|
|
+ for key_no_xlsx_no_space, key_no_xlsx in zip(xlsx_dict_no_space['noKeyList'], xlsx_dict['noKeyList']):
|
|
|
+ key_no_dict[key_no_xlsx_no_space] = []
|
|
|
+ for key_no_res in res_dict['noKeyList']:
|
|
|
+ key_no_dict[key_no_xlsx_no_space].append(
|
|
|
+ (Levenshtein_Distance(key_no_xlsx_no_space, key_no_res), key_no_res))
|
|
|
+ sort_NoKey = sorted(key_no_dict[key_no_xlsx_no_space], key=lambda x: x[0])
|
|
|
+ NoKey_min_distance = sort_NoKey[0][0]
|
|
|
+ if NoKey_min_distance == 0:
|
|
|
+ self.table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '✅'])
|
|
|
+ true_num += 1
|
|
|
+ else:
|
|
|
+ self.table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '❌'])
|
|
|
+
|
|
|
+ # 算正确率
|
|
|
+ all_num = len(self.table_result) // 4 - 1
|
|
|
+ rate = true_num / all_num * 100
|
|
|
+ # all_rate.append(rate)
|
|
|
+ statistics = f'共{all_num}个字段,正确{true_num}个,错误{all_num - true_num}个'
|
|
|
+ self.write_header(level=2, title=f'文字识别正确率:{rate}')
|
|
|
+ self.write_header(level=3, title=statistics)
|
|
|
+ self.f.new_table(columns=4, rows=len(self.table_result) // 4, text=self.table_result, text_align='center')
|
|
|
+ return rate, statistics
|
|
|
+
|
|
|
+ # def evaluate_one(xlsx_dict, res_dict):
|
|
|
+ # true_num = 0
|
|
|
+ # # 有key值的比较
|
|
|
+ # for key_yes in res_dict:
|
|
|
+ # if type(res_dict[key_yes]) is str:
|
|
|
+ # if Levenshtein_Distance(res_dict[key_yes], xlsx_dict[key_yes]) == 0:
|
|
|
+ # table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '✅'])
|
|
|
+ # true_num += 1
|
|
|
+ # else:
|
|
|
+ # table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '❌'])
|
|
|
+ # # 无key值的比较
|
|
|
+ # key_no_dict = {}
|
|
|
+ # for key_no_xlsx in xlsx_dict['noKeyList']:
|
|
|
+ # key_no_dict[key_no_xlsx] = []
|
|
|
+ # for key_no_res in res_dict['noKeyList']:
|
|
|
+ # key_no_dict[key_no_xlsx].append((Levenshtein_Distance(key_no_xlsx, key_no_res), key_no_res))
|
|
|
+ # sort_NoKey = sorted(key_no_dict[key_no_xlsx], key=lambda x: x[0])
|
|
|
+ # NoKey_min_distance = sort_NoKey[0][0]
|
|
|
+ # if NoKey_min_distance == 0:
|
|
|
+ # table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '✅'])
|
|
|
+ # true_num += 1
|
|
|
+ # else:
|
|
|
+ # table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '❌'])
|
|
|
+ # # 算正确率
|
|
|
+ # rate = true_num / (len(table_result) / 4)
|
|
|
+ # all_rate.append(rate)
|
|
|
+ # statistics = f'共{len(table_result) // 4}个字段,正确{true_num}个,错误{len(table_result) // 4 - true_num}个'
|
|
|
+ # return "{:.2f}%".format(rate * 100), statistics
|