|
@@ -1,583 +0,0 @@
|
|
-# ---
|
|
|
|
-# jupyter:
|
|
|
|
-# jupytext:
|
|
|
|
-# text_representation:
|
|
|
|
-# extension: .py
|
|
|
|
-# format_name: light
|
|
|
|
-# format_version: '1.5'
|
|
|
|
-# jupytext_version: 1.16.1
|
|
|
|
-# kernelspec:
|
|
|
|
-# display_name: Python 3
|
|
|
|
-# language: python
|
|
|
|
-# name: python3
|
|
|
|
-# ---
|
|
|
|
-
|
|
|
|
-# Put these at the top of every notebook, to get automatic reloading and inline plotting
|
|
|
|
-# %reload_ext autoreload
|
|
|
|
-# %autoreload 2
|
|
|
|
-# %matplotlib inline
|
|
|
|
-
|
|
|
|
-import sys
|
|
|
|
-sys.path.insert(0, './PaddleOCR/')
|
|
|
|
-
|
|
|
|
-# +
|
|
|
|
-from paddleocr import PaddleOCR, PPStructure
|
|
|
|
-from pathlib import Path
|
|
|
|
-
|
|
|
|
-# from IPython.core.display import HTML
|
|
|
|
-from IPython.display import display, HTML
|
|
|
|
-from PIL import Image
|
|
|
|
-import re
|
|
|
|
-import cv2
|
|
|
|
-import base64
|
|
|
|
-from io import BytesIO
|
|
|
|
-from PIL import Image, ImageOps
|
|
|
|
-import numpy as np
|
|
|
|
-from matplotlib import pyplot as plt
|
|
|
|
-from sx_utils import *
|
|
|
|
-import json
|
|
|
|
-import threading
|
|
|
|
-from collections import namedtuple
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-# +
|
|
|
|
-def table_res(img_path, ROTATE=-1):
|
|
|
|
- im = cv2.imread(img_path)
|
|
|
|
- if ROTATE >= 0:
|
|
|
|
- im = cv2.rotate(im, ROTATE)
|
|
|
|
- html = table_engine(im)[0]['res']['html']
|
|
|
|
- return html
|
|
|
|
-
|
|
|
|
-def cal_html_to_chs(html):
|
|
|
|
- res = []
|
|
|
|
- rows = re.split('<tr>', html)
|
|
|
|
- for row in rows:
|
|
|
|
- row = re.split('<td>', row)
|
|
|
|
- cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
|
|
|
|
- rec_str = ''.join(cells)
|
|
|
|
- for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
|
|
|
|
- rec_str = rec_str.replace(tag, '')
|
|
|
|
-
|
|
|
|
- res.append(rec_str)
|
|
|
|
-
|
|
|
|
- rec_res = ''.join(res).replace(' ','')
|
|
|
|
- rec_res = re.split('<tdcolspan="\w+">', rec_res)
|
|
|
|
- rec_res = ''.join(rec_res).replace(' ','')
|
|
|
|
- print(rec_res)
|
|
|
|
- return len(rec_res)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-# -
|
|
|
|
-
|
|
|
|
-lock = threading.Lock()
|
|
|
|
-def read_annos(path):
|
|
|
|
- with open(path, 'r', encoding='utf-8') as f:
|
|
|
|
- return f.readlines()
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-from collections import namedtuple
|
|
|
|
-from typing import List
|
|
|
|
-import Levenshtein
|
|
|
|
-from concurrent.futures import ThreadPoolExecutor
|
|
|
|
-from functools import partial
|
|
|
|
-from tqdm import tqdm
|
|
|
|
-import requests
|
|
|
|
-import shutil
|
|
|
|
-
|
|
|
|
-# +
|
|
|
|
-from decorator import decorator
|
|
|
|
-@decorator
|
|
|
|
-def rule1_decorator(f, *args,**kwargs):
|
|
|
|
- '''
|
|
|
|
- predict_line = ['项目 ', '', '每100克营养素参考值%', '']
|
|
|
|
- '''
|
|
|
|
- predict_line = args[1]
|
|
|
|
- predict_line = f(*args,**kwargs)
|
|
|
|
- idx = predict_line.index('')
|
|
|
|
- try:
|
|
|
|
- if idx == 1:
|
|
|
|
- if '项目' in predict_line[0] and '每100克' in predict_line[2]:
|
|
|
|
- predict_line[1] = '每100克'
|
|
|
|
- r = re.split('每100克', predict_line[2])
|
|
|
|
- if len(r) == 2 and r[1]:
|
|
|
|
- predict_line[2] = r[1]
|
|
|
|
- except IndexError as e:
|
|
|
|
- print(e)
|
|
|
|
- return predict_line
|
|
|
|
-
|
|
|
|
-@decorator
|
|
|
|
-def rule2_decorator(f, *args,**kwargs):
|
|
|
|
- '''
|
|
|
|
- predict_line = ['碳水化合物18.2克', '', '6%', '']
|
|
|
|
- '''
|
|
|
|
- predict_line = args[1]
|
|
|
|
- predict_line = f(*args,**kwargs)
|
|
|
|
- idx = predict_line.index('')
|
|
|
|
- try:
|
|
|
|
- if idx == 1:
|
|
|
|
- if '化合物' in predict_line[0]:
|
|
|
|
- r = re.split('化合物', predict_line[0])
|
|
|
|
- predict_line[0] = '碳水化合物'
|
|
|
|
- if len(r) == 2 and r[1]:
|
|
|
|
- predict_line[1] = r[1]
|
|
|
|
- except IndexError as e:
|
|
|
|
- print(e)
|
|
|
|
- return predict_line
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-@decorator
|
|
|
|
-def rule3_decorator(f, *args,**kwargs):
|
|
|
|
- '''
|
|
|
|
- ['患直质', '1.6克', '3%', '']
|
|
|
|
- ['脂扇', '1.1', '19%', '']
|
|
|
|
- ['碳水化合物', '勿18.2克', '6%', '']
|
|
|
|
-
|
|
|
|
- '''
|
|
|
|
- predict_line = args[1]
|
|
|
|
- predict_line = f(*args,**kwargs)
|
|
|
|
- predict_line = [re.sub('患直质', '蛋白质', s) for s in predict_line]
|
|
|
|
- predict_line = [re.sub('脂扇', '脂肪', s) for s in predict_line]
|
|
|
|
- predict_line = [re.sub('勿(.*克)', '\\1', s) for s in predict_line]
|
|
|
|
- predict_line = [re.sub('毫 克', '毫克', s) for s in predict_line]
|
|
|
|
- return predict_line
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-# +
|
|
|
|
-WordCompare = namedtuple('WordCompare', ['gt', 'predict', 'cls'])
|
|
|
|
-def get_fn(annos, idx):
|
|
|
|
- return json.loads(annos[idx])['filename']
|
|
|
|
-
|
|
|
|
-def get_gt(annos, idx):
|
|
|
|
- return json.loads(annos[idx])['gt']
|
|
|
|
-
|
|
|
|
-def get_img_path(annos, idx, root_path):
|
|
|
|
- return root_path / get_fn(annos, idx)
|
|
|
|
-
|
|
|
|
-def filter_annos(annos: List[str], fns: List[Path]):
|
|
|
|
- res = []
|
|
|
|
- pattern = '"filename": "(.*.jpg)"'
|
|
|
|
-
|
|
|
|
- for anno in annos:
|
|
|
|
- for fn in fns:
|
|
|
|
- m = re.search(fn.name, anno)
|
|
|
|
- if m:
|
|
|
|
- sub = f'"filename": "{str(fn)}"'
|
|
|
|
- new_anno = re.sub(pattern,sub, anno)
|
|
|
|
- res.append(new_anno)
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-def _predict_table2(annos, idx, root_path):
|
|
|
|
- img_path = get_img_path(annos, idx, root_path=root_path)
|
|
|
|
- try:
|
|
|
|
- lock.acquire()
|
|
|
|
- img_path = str(img_path)
|
|
|
|
- html = table_engine(img_path)[0]['res']['html']
|
|
|
|
- return html
|
|
|
|
- finally:
|
|
|
|
- lock.release()
|
|
|
|
- return None
|
|
|
|
-
|
|
|
|
-def _predict_table1(annos, idx, root_path):
|
|
|
|
- img_path = get_img_path(annos, idx, root_path=root_path)
|
|
|
|
- img_path = str(img_path)
|
|
|
|
- img = cv2.imread(img_path)
|
|
|
|
- img_str = cv2bgr_base64(img)
|
|
|
|
- payload = {'image': img_str, 'det': 'conv'}
|
|
|
|
- r = requests.post('http://ocr-table.yili-ocr:8080/ocr_system/table', json=payload)
|
|
|
|
- res = r.json()
|
|
|
|
- if 'status' in res and res['status'] == '000':
|
|
|
|
- return res['result']['html']
|
|
|
|
- else:
|
|
|
|
- return None
|
|
|
|
-
|
|
|
|
-def _predict_table3(annos, idx, root_path):
|
|
|
|
- img_path = get_img_path(annos, idx, root_path=root_path)
|
|
|
|
- img_path = str(img_path)
|
|
|
|
- img = cv2.imread(img_path)
|
|
|
|
- img_str = cv2bgr_base64(img)
|
|
|
|
- payload = {'image': img_str, 'det': 'yes', 'prefer_cell': True}
|
|
|
|
- r = requests.post('http://ocr-table.sxkj.com/ocr_system/table', headers={"Content-type": "application/json"}, json=payload)
|
|
|
|
- res = r.json()
|
|
|
|
- if 'status' in res and res['status'] == '000':
|
|
|
|
- return res['result']['html']
|
|
|
|
- else:
|
|
|
|
- return None
|
|
|
|
-
|
|
|
|
-def predict_table(annos, idx, root_path):
|
|
|
|
- max_retry = 3
|
|
|
|
- def predict_with_retry(*args, retry=0):
|
|
|
|
- if retry > max_retry:
|
|
|
|
- raise RuntimeError("Max retry failed!")
|
|
|
|
- if retry > 0:
|
|
|
|
- print(f"retry = {retry}, idx = {idx}")
|
|
|
|
- try:
|
|
|
|
- return _predict_table3(*args)
|
|
|
|
- except Exception as e:
|
|
|
|
- print(f'request error: {e}')
|
|
|
|
- return predict_with_retry(*args, retry=retry + 1)
|
|
|
|
- return predict_with_retry(annos, idx, root_path)
|
|
|
|
-
|
|
|
|
-class PairLine:
|
|
|
|
- def __init__(self, gt_line, predict_line):
|
|
|
|
- '''
|
|
|
|
- gt_line: ['项目', '每100克', '营养素参考值%', '']
|
|
|
|
- predict_line: ['项目', '', '每100克营养素参考值%', '']
|
|
|
|
- '''
|
|
|
|
- self.gt_line = gt_line
|
|
|
|
- self.predict_line = predict_line
|
|
|
|
- self.result = []
|
|
|
|
- self.cols = 3
|
|
|
|
-
|
|
|
|
- def compare(self):
|
|
|
|
- for i, word in enumerate(self.predict_line):
|
|
|
|
- # print(self.predict_line, self.gt_line)
|
|
|
|
- if not self.gt_line:
|
|
|
|
- for j in range(self.cols):
|
|
|
|
- self.result.append(WordCompare(gt='', predict=word, cls=False))
|
|
|
|
- try:
|
|
|
|
- cls = True if word.strip() == self.gt_line[i].strip() else False
|
|
|
|
- if not word and not self.gt_line[i]:
|
|
|
|
- continue
|
|
|
|
- self.result.append(WordCompare(gt=self.gt_line[i], predict=word, cls=cls))
|
|
|
|
- except IndexError:
|
|
|
|
- self.result.append(WordCompare(gt='', predict=word, cls=False))
|
|
|
|
- for i, word in enumerate(self.gt_line):
|
|
|
|
- if not self.predict_line:
|
|
|
|
- for j in range(self.cols):
|
|
|
|
- self.result.append(WordCompare(gt=word, predict='', cls=False))
|
|
|
|
- if i >= len(self.predict_line):
|
|
|
|
- self.result.append(WordCompare(gt=word, predict='', cls=False))
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def __repr__(self):
|
|
|
|
- return f'gt_line: {self.gt_line}, predict_line: {self.predict_line}'
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-# +
|
|
|
|
-class Table:
|
|
|
|
- def __init__(self, fn, gt_html, predict_html):
|
|
|
|
- self.fn = fn
|
|
|
|
- self.gt_html = gt_html
|
|
|
|
- self.predict_html = predict_html
|
|
|
|
- self.format_lines = []
|
|
|
|
- self.pair_lines = []
|
|
|
|
- self.result = self.get_result()
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- @classmethod
|
|
|
|
- def from_dict(cls, tt):
|
|
|
|
- gt_html = tt['gt_html']
|
|
|
|
- predict_html = tt['predict_html']
|
|
|
|
- fn = Path(tt['fn'])
|
|
|
|
- t = cls(fn, gt_html, predict_html)
|
|
|
|
- return t
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def to_dict(self):
|
|
|
|
- return {
|
|
|
|
- 'fn': str(self.fn),
|
|
|
|
- 'gt_html': self.gt_html,
|
|
|
|
- 'predict_html': self.predict_html,
|
|
|
|
- 'pair_lines': [{'gt_line': o.gt_line, 'predict_line': o.predict_line} for o in self.pair_lines],
|
|
|
|
- 'result': [[{'gt': o.gt, 'predict': o.predict, 'cls': o.cls} for o in line] for line in self.result]
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- def display_image(self):
|
|
|
|
- im = Image.open(self.fn)
|
|
|
|
- return ImageOps.exif_transpose(im)
|
|
|
|
-
|
|
|
|
- def display_predict_html(self):
|
|
|
|
- return HTML(self.format_predict_html)
|
|
|
|
-
|
|
|
|
- def display_gt_html(self):
|
|
|
|
- return HTML(self.gt_html)
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def format_predict_html(self):
|
|
|
|
- if self.format_lines:
|
|
|
|
- header = '<html><body><table><tbody>'
|
|
|
|
- footer = '</tbody></table></body></html>'
|
|
|
|
- COLS = 3
|
|
|
|
- html = []
|
|
|
|
- for i, line in enumerate(self.format_lines):
|
|
|
|
- html.append('<tr>')
|
|
|
|
- for j in range(COLS):
|
|
|
|
- try:
|
|
|
|
- if i == 0 and '成分表' in line[j]:
|
|
|
|
- html.append('<td colspan="3">')
|
|
|
|
- html.append(line[j])
|
|
|
|
- html.append('</td>')
|
|
|
|
- break;
|
|
|
|
- else:
|
|
|
|
- html.append('<td>')
|
|
|
|
- html.append(line[j])
|
|
|
|
- html.append('</td>')
|
|
|
|
- except IndexError as e:
|
|
|
|
- print('format_predict_html', e)
|
|
|
|
- html.append('<td>')
|
|
|
|
- html.append('')
|
|
|
|
- html.append('</td>')
|
|
|
|
- continue
|
|
|
|
- html.append('</tr>')
|
|
|
|
- res = f'{header}{"".join(html)}{footer}'
|
|
|
|
- return res
|
|
|
|
- else:
|
|
|
|
- return self.predict_html
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def error_rate(self):
|
|
|
|
- corrects = 0
|
|
|
|
- errors = 0
|
|
|
|
- for line in self.result:
|
|
|
|
- for word in line:
|
|
|
|
- if word.cls:
|
|
|
|
- corrects += 1
|
|
|
|
- else:
|
|
|
|
- errors += 1
|
|
|
|
-
|
|
|
|
- total = corrects + errors
|
|
|
|
- return 0 if errors == 1 else errors / total
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def precision(self):
|
|
|
|
- corrects = 0
|
|
|
|
- p_len = 0
|
|
|
|
- for line in self.result:
|
|
|
|
- for word in line:
|
|
|
|
- if word.cls:
|
|
|
|
- corrects += 1
|
|
|
|
- if word.predict:
|
|
|
|
- p_len += 1
|
|
|
|
- return 0 if p_len == 0 else corrects / p_len
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def recall(self):
|
|
|
|
- corrects = 0
|
|
|
|
- g_len = 0
|
|
|
|
- for line in self.result:
|
|
|
|
- for word in line:
|
|
|
|
- if word.cls:
|
|
|
|
- corrects += 1
|
|
|
|
- if word.gt:
|
|
|
|
- g_len += 1
|
|
|
|
- return 0 if g_len == 0 else corrects / g_len
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def hmean(self):
|
|
|
|
- total = self.recall + self.precision
|
|
|
|
- return 0 if total == 0 else 2 * self.precision * self.recall / total
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def get_result(self):
|
|
|
|
- res = []
|
|
|
|
- self._generate_pair_lines()
|
|
|
|
- # print(self.pair_lines)
|
|
|
|
- for pair_line in self.pair_lines:
|
|
|
|
- pair_line.compare()
|
|
|
|
- res.append(pair_line.result)
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- @rule3_decorator
|
|
|
|
- @rule2_decorator
|
|
|
|
- @rule1_decorator
|
|
|
|
- def _format_predict_line(self, predict_line):
|
|
|
|
- return predict_line
|
|
|
|
-
|
|
|
|
- def _get_lines(self, html) -> List[str]:
|
|
|
|
- '''
|
|
|
|
- res: ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
|
|
|
|
- '''
|
|
|
|
- if not html:
|
|
|
|
- return []
|
|
|
|
- rows = re.split('<tr>', html)
|
|
|
|
- res = []
|
|
|
|
- for row in rows:
|
|
|
|
- m = re.findall('<td.*>.*</td>', row)
|
|
|
|
- if m:
|
|
|
|
- res.extend(m)
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def _generate_pair_lines(self):
|
|
|
|
- gt_lines = self._get_lines(self.gt_html)
|
|
|
|
- predict_lines = self._get_lines(self.predict_html)
|
|
|
|
- gt_words_list = [self._split_to_words(line) for line in gt_lines]
|
|
|
|
- predict_words_list = [self._format_predict_line(self._split_to_words(line)) for line in predict_lines]
|
|
|
|
- self.format_lines.extend(predict_words_list)
|
|
|
|
-
|
|
|
|
- DistEntry = namedtuple('DistEntry', ['i', 'j', 'dist'])
|
|
|
|
- dist_entries = []
|
|
|
|
- p = [False] * len(predict_words_list)
|
|
|
|
- g = [False] * len(gt_words_list)
|
|
|
|
- for i, p_line in enumerate(predict_words_list):
|
|
|
|
- for j, g_line in enumerate(gt_words_list):
|
|
|
|
- dist = Levenshtein.distance(''.join(p_line), ''.join(g_line))
|
|
|
|
- dist_entries.append(DistEntry(i=i, j=j, dist=dist))
|
|
|
|
- dist_entries.sort(key=lambda e: e.dist)
|
|
|
|
- for e in dist_entries:
|
|
|
|
- if not p[e.i] and not g[e.j]:
|
|
|
|
- p[e.i] = True
|
|
|
|
- g[e.j] = True
|
|
|
|
- self.pair_lines.append(PairLine(predict_line=predict_words_list[e.i], gt_line=gt_words_list[e.j]))
|
|
|
|
- for i in range(len(p)):
|
|
|
|
- if not p[i]:
|
|
|
|
- self.pair_lines.append(PairLine(predict_line=predict_words_list[i], gt_line=[]))
|
|
|
|
- for i in range(len(g)):
|
|
|
|
- if not g[i]:
|
|
|
|
- self.pair_lines.append(PairLine(predict_line=[], gt_line=gt_words_list[i]))
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def _match_gt_line(self, line, gt_lines):
|
|
|
|
- line = ''.join(line)
|
|
|
|
- min_dist = 9999
|
|
|
|
- res = []
|
|
|
|
- for i, gt_line in enumerate(gt_lines):
|
|
|
|
- gt_line = ''.join(gt_line)
|
|
|
|
- dist = Levenshtein.distance(gt_line, line)
|
|
|
|
- if dist < min_dist:
|
|
|
|
- min_dist = dist
|
|
|
|
- res = gt_lines[i]
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def _split_to_words(self, line):
|
|
|
|
- '''
|
|
|
|
- line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'
|
|
|
|
- res: ['项目', '每100克', '营养素参考值%', '']
|
|
|
|
- '''
|
|
|
|
- res = [re.sub('<td.*>', '', word) for word in re.split('</td>', line)]
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-def generate_tables(annos, root_path, i):
|
|
|
|
- predict_html = predict_table(annos, i, root_path=root_path)
|
|
|
|
- gt_html = get_gt(annos, i)
|
|
|
|
- fn = get_img_path(annos, i, root_path=root_path)
|
|
|
|
- table = Table(fn, gt_html=gt_html, predict_html=predict_html)
|
|
|
|
- return table
|
|
|
|
-
|
|
|
|
-class TableDataset:
|
|
|
|
- def __init__(self, root_path=None, anno_fn=None):
|
|
|
|
- if root_path and anno_fn:
|
|
|
|
- self.tables = []
|
|
|
|
- annos = read_annos(anno_fn)
|
|
|
|
- l = len(annos)
|
|
|
|
- # l = 10
|
|
|
|
- with ThreadPoolExecutor(max_workers=10) as executor:
|
|
|
|
- tables = list(tqdm(executor.map(partial(generate_tables, annos, root_path), range(l)), total=l))
|
|
|
|
- for table in tables:
|
|
|
|
- self.tables.append(table)
|
|
|
|
- else:
|
|
|
|
- self.tables = []
|
|
|
|
- @property
|
|
|
|
- def correct_num(self):
|
|
|
|
- return len(list(filter(lambda x: x.error_rate == 0., self.tables)))
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def avg_error_rate(self):
|
|
|
|
- return np.mean([o.error_rate for o in self.tables])
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def avg_precision(self):
|
|
|
|
- return np.mean([o.precision for o in self.tables])
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def avg_recall(self):
|
|
|
|
- return np.mean([o.recall for o in self.tables])
|
|
|
|
-
|
|
|
|
- @property
|
|
|
|
- def avg_hmean(self):
|
|
|
|
- return np.mean([o.hmean for o in self.tables])
|
|
|
|
-
|
|
|
|
- def save_hard_cases_for_dataset(self, th, root_path, anno_fn, dst_path):
|
|
|
|
- tables = self.top_error_tables_by_threshold(th)
|
|
|
|
- fns = [t.fn for t in tables]
|
|
|
|
- if not (root_path / dst_path).exists():
|
|
|
|
- (root_path / dst_path).mkdir()
|
|
|
|
- for fn in tqdm(fns):
|
|
|
|
- src = fn
|
|
|
|
- dst = root_path / dst_path / fn.name
|
|
|
|
- shutil.copy2(src, dst)
|
|
|
|
- fns = [dst_path / t.fn.name for t in tables]
|
|
|
|
- annos = read_annos(anno_fn)
|
|
|
|
- annos = filter_annos(annos, fns)
|
|
|
|
- write_annos(annos, root_path / dst_path / 'gt.txt')
|
|
|
|
-
|
|
|
|
- def top_error_tables_by_threshold(self, th):
|
|
|
|
- res = []
|
|
|
|
- for r in self.tables:
|
|
|
|
- if r.error_rate >= th:
|
|
|
|
- res.append(r)
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- def top_error_tables(self, k):
|
|
|
|
- tables = sorted(self.tables, key=lambda x: x.error_rate, reverse=True)
|
|
|
|
- return tables[:k]
|
|
|
|
-
|
|
|
|
- def to_json(self, fn):
|
|
|
|
- res = [o.to_dict() for o in self.tables]
|
|
|
|
- with open(fn, 'w', encoding='utf-8') as f:
|
|
|
|
- json.dump(res, f)
|
|
|
|
-
|
|
|
|
- @classmethod
|
|
|
|
- def from_json(cls, fn):
|
|
|
|
- res = cls()
|
|
|
|
- with open(fn, 'r') as f:
|
|
|
|
- ts = json.load(f)
|
|
|
|
- for o in ts:
|
|
|
|
- res.tables.append(Table.from_dict(o))
|
|
|
|
- return res
|
|
|
|
-
|
|
|
|
-# -
|
|
|
|
-
|
|
|
|
-root_path = Path('table-dataset')
|
|
|
|
-anno_fn = root_path / 'merge.txt'
|
|
|
|
-# anno_fn = root_path / 'hardimgs/gt.txt'
|
|
|
|
-
|
|
|
|
-# + jupyter={"outputs_hidden": true}
|
|
|
|
-import time
|
|
|
|
-cur_time = int(time.time())
|
|
|
|
-table_ds = TableDataset(root_path, anno_fn)
|
|
|
|
-json_filename = f'table_merge_{cur_time}.json'
|
|
|
|
-table_ds.to_json(json_filename)
|
|
|
|
-print(json_filename)
|
|
|
|
-
|
|
|
|
-# +
|
|
|
|
-# table_ds = TableDataset.from_json('table_unconv_request_post_1708499925.json')
|
|
|
|
-# -
|
|
|
|
-
|
|
|
|
-table_ds.avg_precision, table_ds.avg_recall, table_ds.avg_hmean, table_ds.correct_num
|
|
|
|
-
|
|
|
|
-k = 100
|
|
|
|
-errors = table_ds.top_error_tables(k)
|
|
|
|
-
|
|
|
|
-tt = errors[9]
|
|
|
|
-
|
|
|
|
-for i, tt in enumerate(errors[:30]):
|
|
|
|
- print(f'id: {i}, precision: {tt.precision}, recall: {tt.recall}')
|
|
|
|
- display(tt.display_image())
|
|
|
|
- display(tt.display_predict_html())
|
|
|
|
- display(tt.display_gt_html())
|
|
|
|
-
|
|
|
|
-tt.display_image()
|
|
|
|
-
|
|
|
|
-tt.display_predict_html()
|
|
|
|
-
|
|
|
|
-tt.display_gt_html()
|
|
|
|
-
|
|
|
|
-tt.predict_html
|
|
|
|
-
|
|
|
|
-tt.gt_html
|
|
|
|
-
|
|
|
|
-tt.fn
|
|
|
|
-
|
|
|
|
-len(table_ds.top_error_tables_by_threshold(0.1))
|
|
|
|
-
|
|
|
|
-table_ds.save_hard_cases_for_dataset(0.1, root_path, anno_fn, Path('hardimgs'))
|
|
|
|
-
|
|
|
|
-
|
|
|