123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- from typing import List
- from mdutils.mdutils import MdUtils
- from YQ_OCR.to_md.datasets import Dataset
class TableMD(object):
    """Markdown report comparing recognised table cells with ground truth.

    One instance writes one Markdown file (through ``MdUtils``) named after the
    image. ``write_table_accuracy`` appends a per-cell comparison table for the
    new or the old model; ``get_table_accuracy`` returns the accuracy computed
    by the most recent comparison.
    """

    # Directory (with trailing separator) for the Markdown report and for
    # ``worst.txt``, the list of low-accuracy images.
    output_dir = '../output/'

    def __init__(self, img_name):
        """Create the report for ``img_name`` and write its top-level title.

        :param img_name: image file name such as ``'scan.png'``; only the final
            extension is dropped when naming the report file.
        """
        self.img_name = img_name
        # Accuracy in percent (0-100), filled in by write_table_accuracy().
        self.acc = 0
        # Strip only the last extension so 'scan.1.png' and 'scan.2.png'
        # produce distinct reports instead of both writing 'scan-...'.
        stem = self.img_name.rsplit('.', 1)[0]
        self.f = MdUtils(file_name=self.output_dir + stem + '-表格识别结果')
        # Not referenced elsewhere in this class.
        self.table_structure: List = ['原模型表格正确率', '新模型表格准确率']
        # 4-column header rows; write_table_accuracy() appends result rows here.
        self.new_table_text: List = ['位置', '标注结果', '新模型推理', '是否一致']
        self.old_table_text: List = ['位置', '标注结果', '原模型推理', '是否一致']
        self.write_header('表格识别结果测试报告')

    def write_header(self, title, level=1):
        """Write a Markdown header ``title`` at heading ``level``."""
        self.f.new_header(level=level, title=title)

    @staticmethod
    def _padded(a, b):
        """Return copies of ``a`` and ``b`` right-padded with '' to equal length."""
        size = max(len(a), len(b))
        return (list(a) + [''] * (size - len(a)),
                list(b) + [''] * (size - len(b)))

    @staticmethod
    def _split_cells(row):
        """Split a '*'-separated row into cells, dropping the text after the last '*'."""
        cells = row.split('*')
        # str.split always yields at least one element, so pop() is safe;
        # presumably rows end with a trailing '*' -- TODO confirm with Dataset.
        cells.pop()
        return cells

    @staticmethod
    def _cells_match(gt_cell, pre_cell):
        """Return True when two cells are equal after removing all spaces."""
        return gt_cell.replace(' ', '') == pre_cell.replace(' ', '')

    def write_table_accuracy(self, ds: 'Dataset', key, columns=4, text_align='center'):
        """Compare predicted cells with ground truth and write a report section.

        ``ds.get_pre_list()`` and ``ds.get_gt_list()`` each return one string
        per table row with cells separated by '*'. Rows or cells present on
        only one side are padded with '' and therefore count as mismatches.
        A cell is correct when both sides are equal after removing spaces; the
        same test drives the ✅/❌ mark, so the table agrees with the totals.
        ``self.acc`` is set to the percentage of correct cells (0.0 when there
        is nothing to compare). The datasets' lists are not modified.

        :param ds: object providing ``get_pre_list()`` and ``get_gt_list()``.
        :param key: ``'new'`` or ``'old'`` -- which model's table receives the
            rows and is written. Any other value only updates ``self.acc``.
            Rows accumulate on the instance, so repeating a key re-writes the
            earlier rows too.
        :param columns: table column count; each result row has 4 fields.
        :param text_align: alignment forwarded to ``MdUtils.new_table``.
        """
        if key == 'new':
            table_text = self.new_table_text
        elif key == 'old':
            table_text = self.old_table_text
        else:
            table_text = None

        correct = 0
        count = 0
        # _padded copies, so the Dataset's own lists are never extended.
        gt_rows, pre_rows = self._padded(ds.get_gt_list(), ds.get_pre_list())
        for row_no, (gt_row, pre_row) in enumerate(zip(gt_rows, pre_rows), start=1):
            gt_cells, pre_cells = self._padded(self._split_cells(gt_row),
                                               self._split_cells(pre_row))
            for gt_cell, pre_cell in zip(gt_cells, pre_cells):
                count += 1
                matched = self._cells_match(gt_cell, pre_cell)
                if matched:
                    correct += 1
                if table_text is not None:
                    table_text.extend([f'{row_no}行', gt_cell, pre_cell,
                                       '✅' if matched else '❌'])

        # An empty comparison would otherwise raise ZeroDivisionError.
        acc = correct / count * 100 if count else 0.0
        self.acc = acc
        if table_text is None:
            return

        self.write_header(level=3,
                          title=f'{self.img_name},'
                                f'共检测{count}处,'
                                f'正确{correct},'
                                f'错误{count - correct},'
                                f'表格正确率:{acc:.2f}%')
        rows = len(table_text) // columns
        self.f.new_table(columns=columns, rows=rows, text=table_text, text_align=text_align)

    def get_table_accuracy(self, threshold=60.0):
        """Return ``self.acc``; append ``img_name`` to worst.txt if it is low.

        :param threshold: percentage (0-100, same scale as ``self.acc``) below
            which the image is recorded in ``output_dir + 'worst.txt'``.
        :return: the accuracy percentage from the last comparison.
        """
        # self.acc is a percentage, so the threshold is on the 0-100 scale too.
        if self.acc < threshold:
            with open(self.output_dir + 'worst.txt', 'a', encoding='utf-8') as f:
                f.write(self.img_name + '\n')
        return self.acc
|