123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- import copy
- from typing import List
- from mdutils.mdutils import MdUtils
- from YQ_OCR.utils.datasets import Dataset
- from YQ_OCR.utils.utils import Levenshtein_Distance
class TableMD(object):
    """Collect OCR evaluation results and write them into a Markdown report.

    The report is built through an ``MdUtils`` instance stored in ``self.f``.
    Two kinds of comparison are supported:

    * ``write_table_accuracy`` - cell-by-cell comparison of table rows taken
      from a dataset object.
    * ``evaluate_one`` - field-by-field comparison of two dictionaries.

    Runtime strings (headers, column titles, result marks) are emitted in
    Chinese exactly as the report format expects.
    """

    def __init__(self, img_name):
        """Create the report document for one image and write its top header.

        Args:
            img_name: File name of the evaluated image. The text before the
                first ``'.'`` is used as the stem of the report file name.
        """
        self.img_name = img_name
        # Accuracy in percent (0-100) of the latest write_table_accuracy call.
        self.acc = 0
        self.f = MdUtils(file_name='./output/' + self.img_name.split('.')[0] + '-表格识别结果')
        # Each list below is the flat cell list of a 4-column Markdown table;
        # the first four entries are the column titles.
        self.table_structure: List = ['原模型表格正确率', '新模型表格准确率']
        self.table_result: List = ['key值', '正确答案', 'ocr返回结果', '是否正确']
        self.new_table_text: List = ['位置', '标注结果', '新模型推理', '是否一致']
        self.old_table_text: List = ['位置', '标注结果', '原模型推理', '是否一致']
        self.write_header('测试结果报告')

    def write_header(self, title, level=1):
        """Append a header with the given ``title`` and ``level`` to the report."""
        self.f.new_header(level=level, title=title)

    @staticmethod
    def _split_cells(line):
        """Split a ``'*'``-separated row into cells, dropping the piece after the last ``'*'``."""
        return line.split('*')[:-1]

    @staticmethod
    def _pad_to_same_length(first, second):
        """Extend the shorter of two lists in place with ``''`` until the lengths match."""
        gap = len(first) - len(second)
        if gap < 0:
            first.extend([''] * -gap)
        else:
            second.extend([''] * gap)

    @staticmethod
    def _cells_match(expected, actual):
        """Return True when two cells are equal once spaces are ignored."""
        return expected.replace(' ', '') == actual.replace(' ', '')

    def write_table_accuracy(self, ds: "Dataset", key, columns=4, text_align='center'):
        """Compare predicted table rows with annotated rows and report the result.

        Rows come from ``ds.get_pre_list()`` (predictions) and
        ``ds.get_gt_list()`` (annotations). Each row is a string whose cells
        are terminated by ``'*'``. Missing rows and missing cells are padded
        with ``''`` so that every annotated or predicted cell is counted.
        Two cells match when they are equal after removing spaces.

        The accuracy, in percent, is stored in ``self.acc``. When ``key`` is
        ``'new'`` or ``'old'`` the per-cell rows are appended to
        ``self.new_table_text`` / ``self.old_table_text`` and two headers plus
        the table are written to the report; for any other ``key`` only
        ``self.acc`` is updated.

        Args:
            ds: Object providing ``get_pre_list()`` and ``get_gt_list()``.
            key: ``'new'`` or ``'old'``; selects which result table is used.
            columns: Number of columns passed to the Markdown table.
            text_align: Cell alignment passed to the Markdown table.
        """
        # Work on copies: padding must not leak into the dataset's own lists.
        pre_rows = list(ds.get_pre_list())
        gt_rows = list(ds.get_gt_list())
        self._pad_to_same_length(pre_rows, gt_rows)

        table_text = {'new': self.new_table_text, 'old': self.old_table_text}.get(key)

        correct = 0
        count = 0
        for row_no, (gt_row, pre_row) in enumerate(zip(gt_rows, pre_rows), start=1):
            gt_cells = self._split_cells(gt_row)
            pre_cells = self._split_cells(pre_row)
            self._pad_to_same_length(pre_cells, gt_cells)
            for gt_cell, pre_cell in zip(gt_cells, pre_cells):
                count += 1
                matched = self._cells_match(gt_cell, pre_cell)
                if matched:
                    correct += 1
                if table_text is not None:
                    table_text.extend(
                        [f'{row_no}行', gt_cell, pre_cell, '✅' if matched else '❌'])

        # An empty dataset has nothing to compare; report 0 instead of
        # failing with ZeroDivisionError.
        acc = correct / count * 100 if count else 0.0
        self.acc = acc

        if table_text is not None:
            rows = len(table_text) // columns
            self.write_header(level=2, title=f'表格识别正确率:{acc:.2f}%')
            self.write_header(level=3, title=f'共检测{count}处,'
                                             f'正确{correct},'
                                             f'错误{count - correct}')
            self.f.new_table(columns=columns, rows=rows, text=table_text, text_align=text_align)

    def get_table_accuracy(self):
        """Return ``self.acc``; also log the image name when accuracy is very low.

        When ``self.acc`` is below ``0.6`` the image name is appended as one
        line to ``'../output/worst.txt'``.

        Returns:
            The accuracy stored by the last ``write_table_accuracy`` call.
        """
        # NOTE(review): self.acc is stored as a percentage (0-100), so this
        # threshold only triggers below 0.6 %. If "below 60 %" was intended
        # the constant should be 60 - confirm before changing.
        # NOTE(review): this path is relative to '../output' while the report
        # itself goes to './output' - confirm which directory is intended.
        if self.acc < 0.6:
            with open('../output/worst.txt', 'a') as f:
                f.write(self.img_name + '\n')
        return self.acc

    def evaluate_one(self, xlsx_dict, res_dict):
        """Compare two result dictionaries and write the comparison to the report.

        Keyed fields: every string value in ``res_dict`` is compared with the
        value under the same key in ``xlsx_dict`` after spaces are removed
        from the latter. Non-string values in ``res_dict`` are skipped here.

        Un-keyed fields: every entry of ``xlsx_dict['noKeyList']`` (spaces
        removed) is paired with the entry of ``res_dict['noKeyList']`` that has
        the smallest ``Levenshtein_Distance`` to it; the pair is correct when
        that distance is 0. If ``res_dict['noKeyList']`` is empty the entry is
        reported as wrong with an empty OCR result.

        Result rows are appended to ``self.table_result``, which is then
        written as a table together with two summary headers.

        Args:
            xlsx_dict: Expected values; must contain ``'noKeyList'``.
            res_dict: OCR results; must contain ``'noKeyList'``.

        Returns:
            Tuple ``(rate, statistics)``: accuracy in percent as a float and
            the summary sentence written to the report.

        Raises:
            KeyError: If ``'noKeyList'`` is missing from either dictionary, or
                a string-valued key of ``res_dict`` is missing from
                ``xlsx_dict``.
        """
        true_num = 0

        # Space-free view of the expected values; xlsx_dict itself is left
        # untouched because its original text is what the report displays.
        expected_no_space = {}
        for field, value in xlsx_dict.items():
            if isinstance(value, str):
                expected_no_space[field] = value.replace(' ', '')
            elif isinstance(value, list):
                expected_no_space[field] = [item.replace(' ', '') for item in value]
            else:
                expected_no_space[field] = value

        # Fields that have a key.
        for key_yes, predicted in res_dict.items():
            if not isinstance(predicted, str):
                continue
            matched = Levenshtein_Distance(predicted, expected_no_space[key_yes]) == 0
            if matched:
                true_num += 1
            self.table_result.extend(
                [key_yes, xlsx_dict[key_yes], predicted, '✅' if matched else '❌'])

        # Fields without a key: match each expected text to its closest result.
        candidates = res_dict['noKeyList']
        for expected_clean, expected_raw in zip(expected_no_space['noKeyList'],
                                                xlsx_dict['noKeyList']):
            if candidates:
                # min() keeps the first of several equally close candidates.
                distance, closest = min(
                    ((Levenshtein_Distance(expected_clean, candidate), candidate)
                     for candidate in candidates),
                    key=lambda pair: pair[0],
                )
                matched = distance == 0
            else:
                # Nothing was recognised, so there is no candidate to pair with.
                closest, matched = '', False
            if matched:
                true_num += 1
            self.table_result.extend(
                ['无key值', expected_raw, closest, '✅' if matched else '❌'])

        # Accuracy. The first 4 cells of table_result are the column titles.
        # NOTE(review): table_result accumulates across calls while true_num
        # is per call, so this method presumably runs once per instance.
        all_num = len(self.table_result) // 4 - 1
        rate = true_num / all_num * 100 if all_num else 0.0
        statistics = f'共{all_num}个字段,正确{true_num}个,错误{all_num - true_num}个'
        self.write_header(level=2, title=f'文字识别正确率:{rate}')
        self.write_header(level=3, title=statistics)
        self.f.new_table(columns=4, rows=len(self.table_result) // 4, text=self.table_result, text_align='center')
        return rate, statistics
- # def evaluate_one(xlsx_dict, res_dict):
- # true_num = 0
- # # 有key值的比较
- # for key_yes in res_dict:
- # if type(res_dict[key_yes]) is str:
- # if Levenshtein_Distance(res_dict[key_yes], xlsx_dict[key_yes]) == 0:
- # table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '✅'])
- # true_num += 1
- # else:
- # table_result.extend([key_yes, xlsx_dict[key_yes], res_dict[key_yes], '❌'])
- # # 无key值的比较
- # key_no_dict = {}
- # for key_no_xlsx in xlsx_dict['noKeyList']:
- # key_no_dict[key_no_xlsx] = []
- # for key_no_res in res_dict['noKeyList']:
- # key_no_dict[key_no_xlsx].append((Levenshtein_Distance(key_no_xlsx, key_no_res), key_no_res))
- # sort_NoKey = sorted(key_no_dict[key_no_xlsx], key=lambda x: x[0])
- # NoKey_min_distance = sort_NoKey[0][0]
- # if NoKey_min_distance == 0:
- # table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '✅'])
- # true_num += 1
- # else:
- # table_result.extend(['无key值', key_no_xlsx, sort_NoKey[0][1], '❌'])
- # # 算正确率
- # rate = true_num / (len(table_result) / 4)
- # all_rate.append(rate)
- # statistics = f'共{len(table_result) // 4}个字段,正确{true_num}个,错误{len(table_result) // 4 - true_num}个'
- # return "{:.2f}%".format(rate * 100), statistics
|