|
@@ -0,0 +1,583 @@
|
|
|
+# ---
|
|
|
+# jupyter:
|
|
|
+# jupytext:
|
|
|
+# text_representation:
|
|
|
+# extension: .py
|
|
|
+# format_name: light
|
|
|
+# format_version: '1.5'
|
|
|
+# jupytext_version: 1.16.1
|
|
|
+# kernelspec:
|
|
|
+# display_name: Python 3
|
|
|
+# language: python
|
|
|
+# name: python3
|
|
|
+# ---
|
|
|
+
|
|
|
+# Put these at the top of every notebook, to get automatic reloading and inline plotting
|
|
|
+# %reload_ext autoreload
|
|
|
+# %autoreload 2
|
|
|
+# %matplotlib inline
|
|
|
+
|
|
|
+import sys
|
|
|
+sys.path.insert(0, './PaddleOCR/')
|
|
|
+
|
|
|
+# +
|
|
|
+from paddleocr import PaddleOCR, PPStructure
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+# from IPython.core.display import HTML
|
|
|
+from IPython.display import display, HTML
|
|
|
+from PIL import Image
|
|
|
+import re
|
|
|
+import cv2
|
|
|
+import base64
|
|
|
+from io import BytesIO
|
|
|
+from PIL import Image, ImageOps
|
|
|
+import numpy as np
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+from sx_utils import *
|
|
|
+import json
|
|
|
+import threading
|
|
|
+from collections import namedtuple
|
|
|
+
|
|
|
+
|
|
|
+# +
|
|
|
+def table_res(img_path, ROTATE=-1):
|
|
|
+ im = cv2.imread(img_path)
|
|
|
+ if ROTATE >= 0:
|
|
|
+ im = cv2.rotate(im, ROTATE)
|
|
|
+ html = table_engine(im)[0]['res']['html']
|
|
|
+ return html
|
|
|
+
|
|
|
+def cal_html_to_chs(html):
|
|
|
+ res = []
|
|
|
+ rows = re.split('<tr>', html)
|
|
|
+ for row in rows:
|
|
|
+ row = re.split('<td>', row)
|
|
|
+ cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
|
|
|
+ rec_str = ''.join(cells)
|
|
|
+ for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
|
|
|
+ rec_str = rec_str.replace(tag, '')
|
|
|
+
|
|
|
+ res.append(rec_str)
|
|
|
+
|
|
|
+ rec_res = ''.join(res).replace(' ','')
|
|
|
+ rec_res = re.split('<tdcolspan="\w+">', rec_res)
|
|
|
+ rec_res = ''.join(rec_res).replace(' ','')
|
|
|
+ print(rec_res)
|
|
|
+ return len(rec_res)
|
|
|
+
|
|
|
+
|
|
|
+# -
|
|
|
+
|
|
|
+lock = threading.Lock()
|
|
|
+def read_annos(path):
|
|
|
+ with open(path, 'r', encoding='utf-8') as f:
|
|
|
+ return f.readlines()
|
|
|
+
|
|
|
+
|
|
|
+from collections import namedtuple
|
|
|
+from typing import List
|
|
|
+import Levenshtein
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+from functools import partial
|
|
|
+from tqdm import tqdm
|
|
|
+import requests
|
|
|
+import shutil
|
|
|
+
|
|
|
+# +
|
|
|
+from decorator import decorator
|
|
|
+@decorator
|
|
|
+def rule1_decorator(f, *args,**kwargs):
|
|
|
+ '''
|
|
|
+ predict_line = ['项目 ', '', '每100克营养素参考值%', '']
|
|
|
+ '''
|
|
|
+ predict_line = args[1]
|
|
|
+ predict_line = f(*args,**kwargs)
|
|
|
+ idx = predict_line.index('')
|
|
|
+ try:
|
|
|
+ if idx == 1:
|
|
|
+ if '项目' in predict_line[0] and '每100克' in predict_line[2]:
|
|
|
+ predict_line[1] = '每100克'
|
|
|
+ r = re.split('每100克', predict_line[2])
|
|
|
+ if len(r) == 2 and r[1]:
|
|
|
+ predict_line[2] = r[1]
|
|
|
+ except IndexError as e:
|
|
|
+ print(e)
|
|
|
+ return predict_line
|
|
|
+
|
|
|
+@decorator
|
|
|
+def rule2_decorator(f, *args,**kwargs):
|
|
|
+ '''
|
|
|
+ predict_line = ['碳水化合物18.2克', '', '6%', '']
|
|
|
+ '''
|
|
|
+ predict_line = args[1]
|
|
|
+ predict_line = f(*args,**kwargs)
|
|
|
+ idx = predict_line.index('')
|
|
|
+ try:
|
|
|
+ if idx == 1:
|
|
|
+ if '化合物' in predict_line[0]:
|
|
|
+ r = re.split('化合物', predict_line[0])
|
|
|
+ predict_line[0] = '碳水化合物'
|
|
|
+ if len(r) == 2 and r[1]:
|
|
|
+ predict_line[1] = r[1]
|
|
|
+ except IndexError as e:
|
|
|
+ print(e)
|
|
|
+ return predict_line
|
|
|
+
|
|
|
+
|
|
|
+@decorator
|
|
|
+def rule3_decorator(f, *args,**kwargs):
|
|
|
+ '''
|
|
|
+ ['患直质', '1.6克', '3%', '']
|
|
|
+ ['脂扇', '1.1', '19%', '']
|
|
|
+ ['碳水化合物', '勿18.2克', '6%', '']
|
|
|
+
|
|
|
+ '''
|
|
|
+ predict_line = args[1]
|
|
|
+ predict_line = f(*args,**kwargs)
|
|
|
+ predict_line = [re.sub('患直质', '蛋白质', s) for s in predict_line]
|
|
|
+ predict_line = [re.sub('脂扇', '脂肪', s) for s in predict_line]
|
|
|
+ predict_line = [re.sub('勿(.*克)', '\\1', s) for s in predict_line]
|
|
|
+ predict_line = [re.sub('毫 克', '毫克', s) for s in predict_line]
|
|
|
+ return predict_line
|
|
|
+
|
|
|
+
|
|
|
+# +
|
|
|
+WordCompare = namedtuple('WordCompare', ['gt', 'predict', 'cls'])
|
|
|
+def get_fn(annos, idx):
|
|
|
+ return json.loads(annos[idx])['filename']
|
|
|
+
|
|
|
+def get_gt(annos, idx):
|
|
|
+ return json.loads(annos[idx])['gt']
|
|
|
+
|
|
|
+def get_img_path(annos, idx, root_path):
|
|
|
+ return root_path / get_fn(annos, idx)
|
|
|
+
|
|
|
+def filter_annos(annos: List[str], fns: List[Path]):
|
|
|
+ res = []
|
|
|
+ pattern = '"filename": "(.*.jpg)"'
|
|
|
+
|
|
|
+ for anno in annos:
|
|
|
+ for fn in fns:
|
|
|
+ m = re.search(fn.name, anno)
|
|
|
+ if m:
|
|
|
+ sub = f'"filename": "{str(fn)}"'
|
|
|
+ new_anno = re.sub(pattern,sub, anno)
|
|
|
+ res.append(new_anno)
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+def _predict_table2(annos, idx, root_path):
|
|
|
+ img_path = get_img_path(annos, idx, root_path=root_path)
|
|
|
+ try:
|
|
|
+ lock.acquire()
|
|
|
+ img_path = str(img_path)
|
|
|
+ html = table_engine(img_path)[0]['res']['html']
|
|
|
+ return html
|
|
|
+ finally:
|
|
|
+ lock.release()
|
|
|
+ return None
|
|
|
+
|
|
|
+def _predict_table1(annos, idx, root_path):
|
|
|
+ img_path = get_img_path(annos, idx, root_path=root_path)
|
|
|
+ img_path = str(img_path)
|
|
|
+ img = cv2.imread(img_path)
|
|
|
+ img_str = cv2bgr_base64(img)
|
|
|
+ payload = {'image': img_str, 'det': 'conv'}
|
|
|
+ r = requests.post('http://ocr-table.yili-ocr:8080/ocr_system/table', json=payload)
|
|
|
+ res = r.json()
|
|
|
+ if 'status' in res and res['status'] == '000':
|
|
|
+ return res['result']['html']
|
|
|
+ else:
|
|
|
+ return None
|
|
|
+
|
|
|
+def _predict_table3(annos, idx, root_path):
|
|
|
+ img_path = get_img_path(annos, idx, root_path=root_path)
|
|
|
+ img_path = str(img_path)
|
|
|
+ img = cv2.imread(img_path)
|
|
|
+ img_str = cv2bgr_base64(img)
|
|
|
+ payload = {'image': img_str, 'det': 'yes', 'prefer_cell': True}
|
|
|
+ r = requests.post('http://ocr-table.sxkj.com/ocr_system/table', headers={"Content-type": "application/json"}, json=payload)
|
|
|
+ res = r.json()
|
|
|
+ if 'status' in res and res['status'] == '000':
|
|
|
+ return res['result']['html']
|
|
|
+ else:
|
|
|
+ return None
|
|
|
+
|
|
|
+def predict_table(annos, idx, root_path):
|
|
|
+ max_retry = 3
|
|
|
+ def predict_with_retry(*args, retry=0):
|
|
|
+ if retry > max_retry:
|
|
|
+ raise RuntimeError("Max retry failed!")
|
|
|
+ if retry > 0:
|
|
|
+ print(f"retry = {retry}, idx = {idx}")
|
|
|
+ try:
|
|
|
+ return _predict_table3(*args)
|
|
|
+ except Exception as e:
|
|
|
+ print(f'request error: {e}')
|
|
|
+ return predict_with_retry(*args, retry=retry + 1)
|
|
|
+ return predict_with_retry(annos, idx, root_path)
|
|
|
+
|
|
|
+class PairLine:
|
|
|
+ def __init__(self, gt_line, predict_line):
|
|
|
+ '''
|
|
|
+ gt_line: ['项目', '每100克', '营养素参考值%', '']
|
|
|
+ predict_line: ['项目', '', '每100克营养素参考值%', '']
|
|
|
+ '''
|
|
|
+ self.gt_line = gt_line
|
|
|
+ self.predict_line = predict_line
|
|
|
+ self.result = []
|
|
|
+ self.cols = 3
|
|
|
+
|
|
|
+ def compare(self):
|
|
|
+ for i, word in enumerate(self.predict_line):
|
|
|
+ # print(self.predict_line, self.gt_line)
|
|
|
+ if not self.gt_line:
|
|
|
+ for j in range(self.cols):
|
|
|
+ self.result.append(WordCompare(gt='', predict=word, cls=False))
|
|
|
+ try:
|
|
|
+ cls = True if word.strip() == self.gt_line[i].strip() else False
|
|
|
+ if not word and not self.gt_line[i]:
|
|
|
+ continue
|
|
|
+ self.result.append(WordCompare(gt=self.gt_line[i], predict=word, cls=cls))
|
|
|
+ except IndexError:
|
|
|
+ self.result.append(WordCompare(gt='', predict=word, cls=False))
|
|
|
+ for i, word in enumerate(self.gt_line):
|
|
|
+ if not self.predict_line:
|
|
|
+ for j in range(self.cols):
|
|
|
+ self.result.append(WordCompare(gt=word, predict='', cls=False))
|
|
|
+ if i >= len(self.predict_line):
|
|
|
+ self.result.append(WordCompare(gt=word, predict='', cls=False))
|
|
|
+
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return f'gt_line: {self.gt_line}, predict_line: {self.predict_line}'
|
|
|
+
|
|
|
+
|
|
|
+# +
|
|
|
+class Table:
|
|
|
+ def __init__(self, fn, gt_html, predict_html):
|
|
|
+ self.fn = fn
|
|
|
+ self.gt_html = gt_html
|
|
|
+ self.predict_html = predict_html
|
|
|
+ self.format_lines = []
|
|
|
+ self.pair_lines = []
|
|
|
+ self.result = self.get_result()
|
|
|
+
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_dict(cls, tt):
|
|
|
+ gt_html = tt['gt_html']
|
|
|
+ predict_html = tt['predict_html']
|
|
|
+ fn = Path(tt['fn'])
|
|
|
+ t = cls(fn, gt_html, predict_html)
|
|
|
+ return t
|
|
|
+
|
|
|
+
|
|
|
+ def to_dict(self):
|
|
|
+ return {
|
|
|
+ 'fn': str(self.fn),
|
|
|
+ 'gt_html': self.gt_html,
|
|
|
+ 'predict_html': self.predict_html,
|
|
|
+ 'pair_lines': [{'gt_line': o.gt_line, 'predict_line': o.predict_line} for o in self.pair_lines],
|
|
|
+ 'result': [[{'gt': o.gt, 'predict': o.predict, 'cls': o.cls} for o in line] for line in self.result]
|
|
|
+ }
|
|
|
+
|
|
|
+ def display_image(self):
|
|
|
+ im = Image.open(self.fn)
|
|
|
+ return ImageOps.exif_transpose(im)
|
|
|
+
|
|
|
+ def display_predict_html(self):
|
|
|
+ return HTML(self.format_predict_html)
|
|
|
+
|
|
|
+ def display_gt_html(self):
|
|
|
+ return HTML(self.gt_html)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def format_predict_html(self):
|
|
|
+ if self.format_lines:
|
|
|
+ header = '<html><body><table><tbody>'
|
|
|
+ footer = '</tbody></table></body></html>'
|
|
|
+ COLS = 3
|
|
|
+ html = []
|
|
|
+ for i, line in enumerate(self.format_lines):
|
|
|
+ html.append('<tr>')
|
|
|
+ for j in range(COLS):
|
|
|
+ try:
|
|
|
+ if i == 0 and '成分表' in line[j]:
|
|
|
+ html.append('<td colspan="3">')
|
|
|
+ html.append(line[j])
|
|
|
+ html.append('</td>')
|
|
|
+ break;
|
|
|
+ else:
|
|
|
+ html.append('<td>')
|
|
|
+ html.append(line[j])
|
|
|
+ html.append('</td>')
|
|
|
+ except IndexError as e:
|
|
|
+ print('format_predict_html', e)
|
|
|
+ html.append('<td>')
|
|
|
+ html.append('')
|
|
|
+ html.append('</td>')
|
|
|
+ continue
|
|
|
+ html.append('</tr>')
|
|
|
+ res = f'{header}{"".join(html)}{footer}'
|
|
|
+ return res
|
|
|
+ else:
|
|
|
+ return self.predict_html
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ @property
|
|
|
+ def error_rate(self):
|
|
|
+ corrects = 0
|
|
|
+ errors = 0
|
|
|
+ for line in self.result:
|
|
|
+ for word in line:
|
|
|
+ if word.cls:
|
|
|
+ corrects += 1
|
|
|
+ else:
|
|
|
+ errors += 1
|
|
|
+
|
|
|
+ total = corrects + errors
|
|
|
+ return 0 if errors == 1 else errors / total
|
|
|
+
|
|
|
+ @property
|
|
|
+ def precision(self):
|
|
|
+ corrects = 0
|
|
|
+ p_len = 0
|
|
|
+ for line in self.result:
|
|
|
+ for word in line:
|
|
|
+ if word.cls:
|
|
|
+ corrects += 1
|
|
|
+ if word.predict:
|
|
|
+ p_len += 1
|
|
|
+ return 0 if p_len == 0 else corrects / p_len
|
|
|
+
|
|
|
+ @property
|
|
|
+ def recall(self):
|
|
|
+ corrects = 0
|
|
|
+ g_len = 0
|
|
|
+ for line in self.result:
|
|
|
+ for word in line:
|
|
|
+ if word.cls:
|
|
|
+ corrects += 1
|
|
|
+ if word.gt:
|
|
|
+ g_len += 1
|
|
|
+ return 0 if g_len == 0 else corrects / g_len
|
|
|
+
|
|
|
+ @property
|
|
|
+ def hmean(self):
|
|
|
+ total = self.recall + self.precision
|
|
|
+ return 0 if total == 0 else 2 * self.precision * self.recall / total
|
|
|
+
|
|
|
+
|
|
|
+ def get_result(self):
|
|
|
+ res = []
|
|
|
+ self._generate_pair_lines()
|
|
|
+ # print(self.pair_lines)
|
|
|
+ for pair_line in self.pair_lines:
|
|
|
+ pair_line.compare()
|
|
|
+ res.append(pair_line.result)
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+ @rule3_decorator
|
|
|
+ @rule2_decorator
|
|
|
+ @rule1_decorator
|
|
|
+ def _format_predict_line(self, predict_line):
|
|
|
+ return predict_line
|
|
|
+
|
|
|
+ def _get_lines(self, html) -> List[str]:
|
|
|
+ '''
|
|
|
+ res: ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
|
|
|
+ '''
|
|
|
+ if not html:
|
|
|
+ return []
|
|
|
+ rows = re.split('<tr>', html)
|
|
|
+ res = []
|
|
|
+ for row in rows:
|
|
|
+ m = re.findall('<td.*>.*</td>', row)
|
|
|
+ if m:
|
|
|
+ res.extend(m)
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+ def _generate_pair_lines(self):
|
|
|
+ gt_lines = self._get_lines(self.gt_html)
|
|
|
+ predict_lines = self._get_lines(self.predict_html)
|
|
|
+ gt_words_list = [self._split_to_words(line) for line in gt_lines]
|
|
|
+ predict_words_list = [self._format_predict_line(self._split_to_words(line)) for line in predict_lines]
|
|
|
+ self.format_lines.extend(predict_words_list)
|
|
|
+
|
|
|
+ DistEntry = namedtuple('DistEntry', ['i', 'j', 'dist'])
|
|
|
+ dist_entries = []
|
|
|
+ p = [False] * len(predict_words_list)
|
|
|
+ g = [False] * len(gt_words_list)
|
|
|
+ for i, p_line in enumerate(predict_words_list):
|
|
|
+ for j, g_line in enumerate(gt_words_list):
|
|
|
+ dist = Levenshtein.distance(''.join(p_line), ''.join(g_line))
|
|
|
+ dist_entries.append(DistEntry(i=i, j=j, dist=dist))
|
|
|
+ dist_entries.sort(key=lambda e: e.dist)
|
|
|
+ for e in dist_entries:
|
|
|
+ if not p[e.i] and not g[e.j]:
|
|
|
+ p[e.i] = True
|
|
|
+ g[e.j] = True
|
|
|
+ self.pair_lines.append(PairLine(predict_line=predict_words_list[e.i], gt_line=gt_words_list[e.j]))
|
|
|
+ for i in range(len(p)):
|
|
|
+ if not p[i]:
|
|
|
+ self.pair_lines.append(PairLine(predict_line=predict_words_list[i], gt_line=[]))
|
|
|
+ for i in range(len(g)):
|
|
|
+ if not g[i]:
|
|
|
+ self.pair_lines.append(PairLine(predict_line=[], gt_line=gt_words_list[i]))
|
|
|
+
|
|
|
+
|
|
|
+ def _match_gt_line(self, line, gt_lines):
|
|
|
+ line = ''.join(line)
|
|
|
+ min_dist = 9999
|
|
|
+ res = []
|
|
|
+ for i, gt_line in enumerate(gt_lines):
|
|
|
+ gt_line = ''.join(gt_line)
|
|
|
+ dist = Levenshtein.distance(gt_line, line)
|
|
|
+ if dist < min_dist:
|
|
|
+ min_dist = dist
|
|
|
+ res = gt_lines[i]
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+ def _split_to_words(self, line):
|
|
|
+ '''
|
|
|
+ line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'
|
|
|
+ res: ['项目', '每100克', '营养素参考值%', '']
|
|
|
+ '''
|
|
|
+ res = [re.sub('<td.*>', '', word) for word in re.split('</td>', line)]
|
|
|
+ return res
|
|
|
+
|
|
|
+def generate_tables(annos, root_path, i):
|
|
|
+ predict_html = predict_table(annos, i, root_path=root_path)
|
|
|
+ gt_html = get_gt(annos, i)
|
|
|
+ fn = get_img_path(annos, i, root_path=root_path)
|
|
|
+ table = Table(fn, gt_html=gt_html, predict_html=predict_html)
|
|
|
+ return table
|
|
|
+
|
|
|
+class TableDataset:
|
|
|
+ def __init__(self, root_path=None, anno_fn=None):
|
|
|
+ if root_path and anno_fn:
|
|
|
+ self.tables = []
|
|
|
+ annos = read_annos(anno_fn)
|
|
|
+ l = len(annos)
|
|
|
+ # l = 10
|
|
|
+ with ThreadPoolExecutor(max_workers=10) as executor:
|
|
|
+ tables = list(tqdm(executor.map(partial(generate_tables, annos, root_path), range(l)), total=l))
|
|
|
+ for table in tables:
|
|
|
+ self.tables.append(table)
|
|
|
+ else:
|
|
|
+ self.tables = []
|
|
|
+ @property
|
|
|
+ def correct_num(self):
|
|
|
+ return len(list(filter(lambda x: x.error_rate == 0., self.tables)))
|
|
|
+
|
|
|
+ @property
|
|
|
+ def avg_error_rate(self):
|
|
|
+ return np.mean([o.error_rate for o in self.tables])
|
|
|
+
|
|
|
+ @property
|
|
|
+ def avg_precision(self):
|
|
|
+ return np.mean([o.precision for o in self.tables])
|
|
|
+
|
|
|
+ @property
|
|
|
+ def avg_recall(self):
|
|
|
+ return np.mean([o.recall for o in self.tables])
|
|
|
+
|
|
|
+ @property
|
|
|
+ def avg_hmean(self):
|
|
|
+ return np.mean([o.hmean for o in self.tables])
|
|
|
+
|
|
|
+ def save_hard_cases_for_dataset(self, th, root_path, anno_fn, dst_path):
|
|
|
+ tables = self.top_error_tables_by_threshold(th)
|
|
|
+ fns = [t.fn for t in tables]
|
|
|
+ if not (root_path / dst_path).exists():
|
|
|
+ (root_path / dst_path).mkdir()
|
|
|
+ for fn in tqdm(fns):
|
|
|
+ src = fn
|
|
|
+ dst = root_path / dst_path / fn.name
|
|
|
+ shutil.copy2(src, dst)
|
|
|
+ fns = [dst_path / t.fn.name for t in tables]
|
|
|
+ annos = read_annos(anno_fn)
|
|
|
+ annos = filter_annos(annos, fns)
|
|
|
+ write_annos(annos, root_path / dst_path / 'gt.txt')
|
|
|
+
|
|
|
+ def top_error_tables_by_threshold(self, th):
|
|
|
+ res = []
|
|
|
+ for r in self.tables:
|
|
|
+ if r.error_rate >= th:
|
|
|
+ res.append(r)
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+ def top_error_tables(self, k):
|
|
|
+ tables = sorted(self.tables, key=lambda x: x.error_rate, reverse=True)
|
|
|
+ return tables[:k]
|
|
|
+
|
|
|
+ def to_json(self, fn):
|
|
|
+ res = [o.to_dict() for o in self.tables]
|
|
|
+ with open(fn, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(res, f)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_json(cls, fn):
|
|
|
+ res = cls()
|
|
|
+ with open(fn, 'r') as f:
|
|
|
+ ts = json.load(f)
|
|
|
+ for o in ts:
|
|
|
+ res.tables.append(Table.from_dict(o))
|
|
|
+ return res
|
|
|
+
|
|
|
+# -
|
|
|
+
|
|
|
+root_path = Path('table-dataset')
|
|
|
+anno_fn = root_path / 'merge.txt'
|
|
|
+# anno_fn = root_path / 'hardimgs/gt.txt'
|
|
|
+
|
|
|
+# + jupyter={"outputs_hidden": true}
|
|
|
+import time
|
|
|
+cur_time = int(time.time())
|
|
|
+table_ds = TableDataset(root_path, anno_fn)
|
|
|
+json_filename = f'table_merge_{cur_time}.json'
|
|
|
+table_ds.to_json(json_filename)
|
|
|
+print(json_filename)
|
|
|
+
|
|
|
+# +
|
|
|
+# table_ds = TableDataset.from_json('table_unconv_request_post_1708499925.json')
|
|
|
+# -
|
|
|
+
|
|
|
+table_ds.avg_precision, table_ds.avg_recall, table_ds.avg_hmean, table_ds.correct_num
|
|
|
+
|
|
|
+k = 100
|
|
|
+errors = table_ds.top_error_tables(k)
|
|
|
+
|
|
|
+tt = errors[9]
|
|
|
+
|
|
|
+for i, tt in enumerate(errors[:30]):
|
|
|
+ print(f'id: {i}, precision: {tt.precision}, recall: {tt.recall}')
|
|
|
+ display(tt.display_image())
|
|
|
+ display(tt.display_predict_html())
|
|
|
+ display(tt.display_gt_html())
|
|
|
+
|
|
|
+tt.display_image()
|
|
|
+
|
|
|
+tt.display_predict_html()
|
|
|
+
|
|
|
+tt.display_gt_html()
|
|
|
+
|
|
|
+tt.predict_html
|
|
|
+
|
|
|
+tt.gt_html
|
|
|
+
|
|
|
+tt.fn
|
|
|
+
|
|
|
+len(table_ds.top_error_tables_by_threshold(0.1))
|
|
|
+
|
|
|
+table_ds.save_hard_cases_for_dataset(0.1, root_path, anno_fn, Path('hardimgs'))
|
|
|
+
|
|
|
+
|