123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583 |
- # ---
- # jupyter:
- # jupytext:
- # text_representation:
- # extension: .py
- # format_name: light
- # format_version: '1.5'
- # jupytext_version: 1.16.1
- # kernelspec:
- # display_name: Python 3
- # language: python
- # name: python3
- # ---
- # Put these at the top of every notebook, to get automatic reloading and inline plotting
- # %reload_ext autoreload
- # %autoreload 2
- # %matplotlib inline
- import sys
- sys.path.insert(0, './PaddleOCR/')
- # +
- from paddleocr import PaddleOCR, PPStructure
- from pathlib import Path
- # from IPython.core.display import HTML
- from IPython.display import display, HTML
- from PIL import Image
- import re
- import cv2
- import base64
- from io import BytesIO
- from PIL import Image, ImageOps
- import numpy as np
- from matplotlib import pyplot as plt
- from sx_utils import *
- import json
- import threading
- from collections import namedtuple
- # +
def table_res(img_path, ROTATE=-1):
    """Run the local table engine on one image file and return the predicted HTML.

    Args:
        img_path: path of the image (``str`` or ``Path``; converted with ``str()``
            because ``cv2.imread`` expects a string).
        ROTATE: a ``cv2.ROTATE_*`` code applied before recognition; any negative
            value (the default) leaves the image unrotated.

    Returns:
        The ``['res']['html']`` field of the engine's first result.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image. ``cv2.imread`` returns
            None instead of raising, which would otherwise fail later with an
            unrelated error.
    """
    im = cv2.imread(str(img_path))
    if im is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')
    if ROTATE >= 0:
        im = cv2.rotate(im, ROTATE)
    # NOTE(review): `table_engine` is not defined in this section; presumably a
    # PPStructure(...) instance created in another cell — confirm.
    return table_engine(im)[0]['res']['html']
def cal_html_to_chs(html):
    """Count the text characters in a table HTML string.

    Removes the table markup (``<html>``, ``<body>``, ``<table>``, ``<tbody>``,
    ``<tr>``, ``<td>`` and ``<td ...>`` with any attributes such as ``colspan``
    or ``rowspan``, plus their closing tags) and every space. The remaining text
    is printed and its length returned.
    """
    wrapper_tags = ('<html>', '</html>', '<body>', '</body>',
                    '<table>', '</table>', '<tbody>', '</tbody>')
    res = []
    for row in re.split('<tr>', html):
        rec_str = ''.join(
            cell.replace('</td>', '').replace('</tr>', '')
            for cell in re.split('<td>', row)
        )
        for tag in wrapper_tags:
            rec_str = rec_str.replace(tag, '')
        res.append(rec_str)

    # Spaces are removed first, so an attributed cell tag now looks like
    # '<tdcolspan="3">' or '<tdrowspan="2">'; strip any such leftover tag.
    # (Raw string: the old non-raw '\w' was an invalid escape sequence.)
    rec_res = ''.join(res).replace(' ', '')
    rec_res = re.sub(r'<td[^>]*>', '', rec_res)
    print(rec_res)
    return len(rec_res)
- # -
# Serializes calls to the shared local table engine (acquired in _predict_table2).
lock = threading.Lock()


def read_annos(path):
    """Return every line of the UTF-8 annotation file at ``path``, newlines kept."""
    with open(path, encoding='utf-8') as fh:
        lines = fh.readlines()
    return lines
- from collections import namedtuple
- from typing import List
- import Levenshtein
- from concurrent.futures import ThreadPoolExecutor
- from functools import partial
- from tqdm import tqdm
- import requests
- import shutil
- # +
- from decorator import decorator
@decorator
def rule1_decorator(f, *args, **kwargs):
    """Split a merged '每100克' (per 100 g) header cell back into its own column.

    Calls the wrapped function and post-processes the row it returns (a list of
    cell strings). The rule fires only when all of these hold:
      * the first empty cell is at index 1;
      * cell 0 contains '项目' (item);
      * cell 2 contains '每100克'.
    In that case cell 1 becomes '每100克', and cell 2 keeps only the text after it::

        ['项目 ', '', '每100克营养素参考值%', '']
            -> ['项目 ', '每100克', '营养素参考值%', '']

    The list is modified in place and returned. A row with no empty cell at
    all is returned unchanged; previously ``.index('')`` raised ValueError here.
    """
    predict_line = f(*args, **kwargs)
    if '' not in predict_line:
        return predict_line
    try:
        if (predict_line.index('') == 1
                and '项目' in predict_line[0]
                and '每100克' in predict_line[2]):
            predict_line[1] = '每100克'
            r = re.split('每100克', predict_line[2])
            if len(r) == 2 and r[1]:
                predict_line[2] = r[1]
    except IndexError as e:
        # Row has fewer than three cells; report it and leave the row as is.
        print(e)
    return predict_line
@decorator
def rule2_decorator(f, *args, **kwargs):
    """Split a carbohydrate cell whose value was merged into the name.

    Calls the wrapped function and post-processes the row it returns (a list of
    cell strings). The rule fires when the first empty cell is at index 1 and
    cell 0 contains '化合物'. Cell 0 is then set to '碳水化合物' (carbohydrate).
    If exactly one non-empty piece follows '化合物', that piece moves into cell 1::

        ['碳水化合物18.2克', '', '6%', '']  ->  ['碳水化合物', '18.2克', '6%', '']

    The list is modified in place and returned. A row with no empty cell is
    returned unchanged; previously ``.index('')`` raised ValueError here.
    """
    predict_line = f(*args, **kwargs)
    # An empty cell at index 1 guarantees cells 0 and 1 exist, so no IndexError
    # is possible below.
    if '' not in predict_line or predict_line.index('') != 1:
        return predict_line
    if '化合物' in predict_line[0]:
        r = re.split('化合物', predict_line[0])
        predict_line[0] = '碳水化合物'
        if len(r) == 2 and r[1]:
            predict_line[1] = r[1]
    return predict_line
-
-
@decorator
def rule3_decorator(f, *args, **kwargs):
    """Fix recurring OCR misreadings in every cell of the returned row.

    Calls the wrapped function, then applies these regex substitutions in order
    to each cell and returns a new list:

      * '患直质' -> '蛋白质'   (protein)
      * '脂扇'   -> '脂肪'     (fat)
      * '勿<...>克' -> '<...>克' (drop a spurious leading '勿' before a gram value)
      * '毫 克'  -> '毫克'     (milligram with a stray space)

    Examples of rows this targets::

        ['患直质', '1.6克', '3%', '']
        ['脂扇', '1.1', '19%', '']
        ['碳水化合物', '勿18.2克', '6%', '']
    """
    substitutions = (
        ('患直质', '蛋白质'),
        ('脂扇', '脂肪'),
        ('勿(.*克)', '\\1'),
        ('毫 克', '毫克'),
    )
    predict_line = f(*args, **kwargs)
    fixed = []
    for cell in predict_line:
        for pattern, repl in substitutions:
            cell = re.sub(pattern, repl, cell)
        fixed.append(cell)
    return fixed
- # +
# One compared cell: ground-truth text, predicted text, and whether they match.
WordCompare = namedtuple('WordCompare', 'gt predict cls')
def get_fn(annos, idx):
    """Return the ``filename`` field of the JSON annotation line ``annos[idx]``."""
    record = json.loads(annos[idx])
    return record['filename']
def get_gt(annos, idx):
    """Return the ground-truth ``gt`` field of the JSON annotation line ``annos[idx]``."""
    record = json.loads(annos[idx])
    return record['gt']
def get_img_path(annos, idx, root_path):
    """Return the image path of annotation ``idx``: ``root_path / filename``.

    ``root_path`` may be a ``str`` or a ``Path``; the result is always a ``Path``.
    Previously a ``str`` root raised TypeError on ``/``.
    """
    return Path(root_path) / get_fn(annos, idx)
def filter_annos(annos: List[str], fns: List[Path]):
    """Select the annotation lines for ``fns`` and repoint their filenames.

    Each annotation line is expected to contain ``"filename": "<path>"``.
    For every ``fn`` in ``fns`` whose basename equals the basename of a line's
    filename, a copy of that line is emitted with the filename replaced by
    ``str(fn)``. Output order follows ``annos``, then ``fns``. Lines without a
    filename field are skipped.

    The comparison is an exact basename match. The old code ran ``fn.name``
    through ``re.search``, so '.' and '(' acted as regex metacharacters and
    'a.jpg' also matched 'ba.jpg'. Filenames of any extension are rewritten,
    not only '.jpg'.
    """
    filename_re = re.compile(r'"filename": "([^"]*)"')
    res = []
    for anno in annos:
        m = filename_re.search(anno)
        if m is None:
            continue
        anno_name = Path(m.group(1)).name
        for fn in fns:
            if fn.name != anno_name:
                continue
            sub = f'"filename": "{fn}"'
            # Callable replacement: a plain replacement string would interpret
            # backslashes in the path (e.g. Windows separators) as escapes.
            res.append(filename_re.sub(lambda _m, s=sub: s, anno))
    return res
-
-
def _predict_table2(annos, idx, root_path):
    """Predict table ``idx`` with the in-process table engine and return its HTML.

    Calls on the shared ``table_engine`` are serialized with the module-level
    ``lock``, so this function can be mapped from a thread pool.
    """
    img_path = str(get_img_path(annos, idx, root_path=root_path))
    # `with` acquires the lock before the body and always releases it. The old
    # code acquired inside `try`, and its trailing `return None` was unreachable.
    with lock:
        return table_engine(img_path)[0]['res']['html']
def _predict_table1(annos, idx, root_path, timeout=120):
    """Predict table ``idx`` through the ``ocr-table.yili-ocr:8080`` HTTP service.

    The image is sent base64-encoded (``cv2bgr_base64``) with ``det='conv'``.

    Args:
        timeout: seconds to wait for the service. Without a bound, a stalled
            connection blocks the calling thread forever.

    Returns:
        ``result.html`` from the response when ``status == '000'``, else None.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
        requests.RequestException: on connection errors or timeout.
        ValueError: if the response body is not JSON.
    """
    img_path = str(get_img_path(annos, idx, root_path=root_path))
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')
    payload = {'image': cv2bgr_base64(img), 'det': 'conv'}
    r = requests.post('http://ocr-table.yili-ocr:8080/ocr_system/table',
                      json=payload, timeout=timeout)
    res = r.json()
    if 'status' in res and res['status'] == '000':
        return res['result']['html']
    return None
-
def _predict_table3(annos, idx, root_path, timeout=120):
    """Predict table ``idx`` through the ``ocr-table.sxkj.com`` HTTP service.

    The image is sent base64-encoded (``cv2bgr_base64``) with ``det='yes'`` and
    ``prefer_cell=True``.

    Args:
        timeout: seconds to wait for the service. A stalled connection now raises
            instead of hanging a worker thread forever, so ``predict_table``
            can retry it.

    Returns:
        ``result.html`` from the response when ``status == '000'``, else None.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
        requests.RequestException: on connection errors or timeout.
        ValueError: if the response body is not JSON.
    """
    img_path = str(get_img_path(annos, idx, root_path=root_path))
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')
    payload = {'image': cv2bgr_base64(img), 'det': 'yes', 'prefer_cell': True}
    r = requests.post('http://ocr-table.sxkj.com/ocr_system/table',
                      headers={"Content-type": "application/json"},
                      json=payload, timeout=timeout)
    res = r.json()
    if 'status' in res and res['status'] == '000':
        return res['result']['html']
    return None
def predict_table(annos, idx, root_path, max_retry=3):
    """Predict the HTML of table ``idx`` via ``_predict_table3``, retrying on failure.

    Makes up to ``max_retry + 1`` attempts. Each failed attempt prints the
    error; each retry prints its number and ``idx``.

    Returns:
        Whatever ``_predict_table3`` returns (an HTML string or None).

    Raises:
        RuntimeError: "Max retry failed!" once all attempts fail, chained from
            the last underlying error.
    """
    last_error = None
    for retry in range(max_retry + 1):
        if retry > 0:
            print(f"retry = {retry}, idx = {idx}")
        try:
            return _predict_table3(annos, idx, root_path)
        except Exception as e:  # retry boundary: any failure triggers another attempt
            print(f'request error: {e}')
            last_error = e
    raise RuntimeError("Max retry failed!") from last_error
-
class PairLine:
    """A predicted table row paired with the ground-truth row it was matched to.

    Either side may be an empty list when the row has no counterpart.
    ``compare()`` fills ``result`` with one ``WordCompare`` per compared cell.
    """

    def __init__(self, gt_line, predict_line):
        '''
        gt_line: ['项目', '每100克', '营养素参考值%', '']
        predict_line: ['项目', '', '每100克营养素参考值%', '']
        '''
        self.gt_line = gt_line
        self.predict_line = predict_line
        self.result = []
        # Nominal column count of the tables; kept for callers, not used by compare().
        self.cols = 3

    def compare(self):
        """Compare the two rows cell by cell and store the outcome in ``self.result``.

        Cells are aligned by position. A cell missing on one side (shorter row,
        or no matched row at all) is treated as ''. Positions where both cells are
        blank after stripping are skipped. ``cls`` is True when the stripped
        texts are equal.

        Changes from the previous implementation:
          * an unmatched word is recorded once, not ``cols`` + 1 times;
          * trailing empty cells no longer count as errors;
          * whitespace-vs-empty pairs no longer count as correct, which could push
            precision or recall above 1;
          * calling compare() again recomputes the result instead of appending.
        """
        from itertools import zip_longest

        self.result = []
        for predict, gt in zip_longest(self.predict_line, self.gt_line, fillvalue=''):
            if not predict.strip() and not gt.strip():
                continue
            self.result.append(
                WordCompare(gt=gt, predict=predict, cls=predict.strip() == gt.strip())
            )

    def __repr__(self):
        return f'gt_line: {self.gt_line}, predict_line: {self.predict_line}'
- # +
class Table:
    """Evaluation record for one table image: ground-truth vs. predicted HTML.

    Construction does the following, storing one list of ``WordCompare`` per
    row pair in ``self.result``; the metric properties are computed from it:
      * splits both HTML strings into rows of cell texts;
      * cleans the predicted rows with the rule decorators;
      * pairs predicted and ground-truth rows by Levenshtein distance;
      * compares each pair cell by cell.
    """

    def __init__(self, fn, gt_html, predict_html):
        self.fn = fn
        self.gt_html = gt_html
        self.predict_html = predict_html
        # Rule-cleaned predicted rows; filled by _generate_pair_lines().
        self.format_lines = []
        # PairLine objects for matched and unmatched rows; filled by _generate_pair_lines().
        self.pair_lines = []
        # Must run last: get_result() populates the two lists above.
        self.result = self.get_result()

    @classmethod
    def from_dict(cls, tt):
        """Rebuild a Table from a ``to_dict()`` record; the comparison is recomputed."""
        gt_html = tt['gt_html']
        predict_html = tt['predict_html']
        fn = Path(tt['fn'])
        return cls(fn, gt_html, predict_html)

    def to_dict(self):
        """Return a JSON-serializable snapshot of inputs, row pairs and per-cell results."""
        return {
            'fn': str(self.fn),
            'gt_html': self.gt_html,
            'predict_html': self.predict_html,
            'pair_lines': [{'gt_line': o.gt_line, 'predict_line': o.predict_line}
                           for o in self.pair_lines],
            'result': [[{'gt': o.gt, 'predict': o.predict, 'cls': o.cls} for o in line]
                       for line in self.result],
        }

    def display_image(self):
        """Load the table image with its EXIF orientation applied."""
        # exif_transpose returns a new, fully loaded image, so the source file can
        # be closed right away instead of leaking one handle per call.
        with Image.open(self.fn) as im:
            return ImageOps.exif_transpose(im)

    def display_predict_html(self):
        """Return the rebuilt predicted table as an IPython ``HTML`` object."""
        return HTML(self.format_predict_html)

    def display_gt_html(self):
        """Return the ground-truth table as an IPython ``HTML`` object."""
        return HTML(self.gt_html)

    @property
    def format_predict_html(self):
        """Rebuild the predicted table HTML from the rule-cleaned rows.

        Every row gets exactly three cells; missing cells become empty ``<td></td>``.
        In the first row, a cell containing '成分表' (composition table) is emitted
        as a single ``colspan="3"`` cell and ends that row. Falls back to the raw
        ``predict_html`` when there are no cleaned rows.
        """
        if not self.format_lines:
            return self.predict_html
        cols = 3
        html = []
        for i, line in enumerate(self.format_lines):
            html.append('<tr>')
            for j in range(cols):
                # Fetch the cell before emitting any markup. The old code appended
                # '<td>' first, so a short row produced an unclosed '<td><td></td>'.
                try:
                    cell = line[j]
                except IndexError as e:
                    print('format_predict_html', e)
                    html.append('<td></td>')
                    continue
                if i == 0 and '成分表' in cell:
                    html.append(f'<td colspan="3">{cell}</td>')
                    break
                html.append(f'<td>{cell}</td>')
            html.append('</tr>')
        body = ''.join(html)
        return f'<html><body><table><tbody>{body}</tbody></table></body></html>'

    def _iter_words(self):
        """Yield every WordCompare of every row pair."""
        for line in self.result:
            yield from line

    @property
    def error_rate(self):
        """Fraction of compared cells that are wrong; 0 when nothing was compared."""
        total = 0
        errors = 0
        for word in self._iter_words():
            total += 1
            if not word.cls:
                errors += 1
        # Guard on total, not errors: the old `errors == 1` check reported a
        # single-error table as perfect and divided by zero for empty tables.
        return 0 if total == 0 else errors / total

    @property
    def precision(self):
        """Correct cells divided by cells with non-empty predicted text (0 if none)."""
        corrects = 0
        p_len = 0
        for word in self._iter_words():
            if word.cls:
                corrects += 1
            if word.predict:
                p_len += 1
        return 0 if p_len == 0 else corrects / p_len

    @property
    def recall(self):
        """Correct cells divided by cells with non-empty ground-truth text (0 if none)."""
        corrects = 0
        g_len = 0
        for word in self._iter_words():
            if word.cls:
                corrects += 1
            if word.gt:
                g_len += 1
        return 0 if g_len == 0 else corrects / g_len

    @property
    def hmean(self):
        """Harmonic mean (F1) of precision and recall; 0 when both are 0."""
        # Each property walks every cell, so evaluate them once.
        precision, recall = self.precision, self.recall
        total = precision + recall
        return 0 if total == 0 else 2 * precision * recall / total

    def get_result(self):
        """Pair the rows, compare each pair, and return the per-pair result lists."""
        res = []
        self._generate_pair_lines()
        for pair_line in self.pair_lines:
            pair_line.compare()
            res.append(pair_line.result)
        return res

    @rule3_decorator
    @rule2_decorator
    @rule1_decorator
    def _format_predict_line(self, predict_line):
        """Identity; the stacked rule decorators apply the OCR clean-up rules."""
        return predict_line

    def _get_lines(self, html) -> List[str]:
        '''Return each table row's ``<td>...</td>`` run as one string.

        res: ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
        '''
        if not html:
            return []
        rows = re.split('<tr>', html)
        res = []
        for row in rows:
            # Greedy match: spans from the first '<td' to the last '</td>' of the row.
            m = re.findall('<td.*>.*</td>', row)
            if m:
                res.extend(m)
        return res

    def _generate_pair_lines(self):
        """Parse both tables into rows and pair predicted rows with ground-truth rows.

        The matching is greedy. All (predicted, ground-truth) row combinations are
        ranked by the Levenshtein distance of their joined cell texts, and a pair
        is taken whenever neither row has been used yet. Leftover rows are paired
        with an empty list.
        """
        gt_lines = self._get_lines(self.gt_html)
        predict_lines = self._get_lines(self.predict_html)
        gt_words_list = [self._split_to_words(line) for line in gt_lines]
        predict_words_list = [self._format_predict_line(self._split_to_words(line))
                              for line in predict_lines]
        self.format_lines.extend(predict_words_list)

        DistEntry = namedtuple('DistEntry', ['i', 'j', 'dist'])
        dist_entries = []
        p = [False] * len(predict_words_list)
        g = [False] * len(gt_words_list)
        for i, p_line in enumerate(predict_words_list):
            for j, g_line in enumerate(gt_words_list):
                dist = Levenshtein.distance(''.join(p_line), ''.join(g_line))
                dist_entries.append(DistEntry(i=i, j=j, dist=dist))
        # Stable sort: ties keep (i, j) insertion order.
        dist_entries.sort(key=lambda e: e.dist)
        for e in dist_entries:
            if not p[e.i] and not g[e.j]:
                p[e.i] = True
                g[e.j] = True
                self.pair_lines.append(PairLine(predict_line=predict_words_list[e.i],
                                                gt_line=gt_words_list[e.j]))
        for i, used in enumerate(p):
            if not used:
                self.pair_lines.append(PairLine(predict_line=predict_words_list[i], gt_line=[]))
        for i, used in enumerate(g):
            if not used:
                self.pair_lines.append(PairLine(predict_line=[], gt_line=gt_words_list[i]))

    def _match_gt_line(self, line, gt_lines):
        """Return the ground-truth row closest to ``line`` by Levenshtein distance.

        Ties go to the first row. Returns [] if ``gt_lines`` is empty or every
        distance is >= 9999. Not called anywhere in this class.
        """
        line = ''.join(line)
        min_dist = 9999
        res = []
        for i, gt_line in enumerate(gt_lines):
            gt_line = ''.join(gt_line)
            dist = Levenshtein.distance(gt_line, line)
            if dist < min_dist:
                min_dist = dist
                res = gt_lines[i]
        return res

    def _split_to_words(self, line):
        '''Split a row string into cell texts, stripping the opening ``<td...>`` tags.

        line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'
        res: ['项目', '每100克', '营养素参考值%', '']
        The trailing '' comes from the text after the last '</td>'.
        '''
        res = [re.sub('<td.*>', '', word) for word in re.split('</td>', line)]
        return res
def generate_tables(annos, root_path, i):
    """Build the evaluated ``Table`` for annotation line ``i``.

    The argument order lets ``partial(generate_tables, annos, root_path)`` be
    mapped over annotation indices. The prediction is requested first, then the
    ground truth and image path are read from the annotation.
    """
    predict_html = predict_table(annos, i, root_path=root_path)
    gt_html = get_gt(annos, i)
    img_path = get_img_path(annos, i, root_path=root_path)
    return Table(fn=img_path, gt_html=gt_html, predict_html=predict_html)
class TableDataset:
    """A collection of evaluated ``Table`` objects with dataset-level metrics."""

    def __init__(self, root_path=None, anno_fn=None, max_workers=10):
        """Evaluate every annotation line, or start empty.

        Args:
            root_path: directory that annotation filenames are relative to.
            anno_fn: annotation file with one JSON record per line.
            max_workers: number of concurrent prediction threads.

        When either ``root_path`` or ``anno_fn`` is missing the dataset starts
        empty; ``from_json`` relies on this.
        """
        self.tables = []
        if not (root_path and anno_fn):
            return
        annos = read_annos(anno_fn)
        total = len(annos)
        build = partial(generate_tables, annos, root_path)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields in annotation order; tqdm only reports progress.
            self.tables = list(tqdm(executor.map(build, range(total)), total=total))

    @property
    def correct_num(self):
        """Number of tables with no cell errors."""
        return sum(1 for t in self.tables if t.error_rate == 0.)

    @property
    def avg_error_rate(self):
        """Mean per-table error rate (nan for an empty dataset)."""
        return np.mean([o.error_rate for o in self.tables])

    @property
    def avg_precision(self):
        """Mean per-table precision (nan for an empty dataset)."""
        return np.mean([o.precision for o in self.tables])

    @property
    def avg_recall(self):
        """Mean per-table recall (nan for an empty dataset)."""
        return np.mean([o.recall for o in self.tables])

    @property
    def avg_hmean(self):
        """Mean per-table harmonic mean (nan for an empty dataset)."""
        return np.mean([o.hmean for o in self.tables])

    def save_hard_cases_for_dataset(self, th, root_path, anno_fn, dst_path):
        """Export tables with ``error_rate >= th`` as a sub-dataset.

        Copies their images into ``root_path / dst_path`` and writes a matching
        ``gt.txt`` there. Its filenames are rewritten to ``dst_path/<name>``,
        relative to ``root_path``. ``dst_path`` may be a ``str`` or a ``Path``.
        The directory, including missing parents, is created if needed.
        """
        dst_path = Path(dst_path)
        tables = self.top_error_tables_by_threshold(th)
        dst_dir = root_path / dst_path
        dst_dir.mkdir(parents=True, exist_ok=True)
        for table in tqdm(tables):
            shutil.copy2(table.fn, dst_dir / table.fn.name)
        fns = [dst_path / t.fn.name for t in tables]
        annos = filter_annos(read_annos(anno_fn), fns)
        # NOTE(review): write_annos is not defined in this file section;
        # presumably provided by `from sx_utils import *` — confirm.
        write_annos(annos, dst_dir / 'gt.txt')

    def top_error_tables_by_threshold(self, th):
        """Return the tables whose error rate is at least ``th``, in dataset order."""
        return [t for t in self.tables if t.error_rate >= th]

    def top_error_tables(self, k):
        """Return the ``k`` tables with the highest error rate, worst first."""
        return sorted(self.tables, key=lambda x: x.error_rate, reverse=True)[:k]

    def to_json(self, fn):
        """Write every table's ``to_dict()`` record to ``fn`` as a JSON list."""
        res = [o.to_dict() for o in self.tables]
        with open(fn, 'w', encoding='utf-8') as f:
            json.dump(res, f)

    @classmethod
    def from_json(cls, fn):
        """Load a dataset written by ``to_json``; each table's comparison is recomputed."""
        res = cls()
        # Read with the same encoding to_json writes with.
        with open(fn, 'r', encoding='utf-8') as f:
            ts = json.load(f)
        for o in ts:
            res.tables.append(Table.from_dict(o))
        return res
- # -
root_path = Path('table-dataset')
anno_fn = root_path / 'merge.txt'
# anno_fn = root_path / 'hardimgs/gt.txt'

# + jupyter={"outputs_hidden": true}
# Evaluate every annotated table and cache the results in a timestamped JSON
# file, so a later session can reload them with TableDataset.from_json.
import time
cur_time = int(time.time())
table_ds = TableDataset(root_path, anno_fn)
json_filename = f'table_merge_{cur_time}.json'
table_ds.to_json(json_filename)
print(json_filename)

# +
# table_ds = TableDataset.from_json('table_unconv_request_post_1708499925.json')
# -

# Dataset-level metrics: precision, recall, hmean, number of error-free tables.
print(table_ds.avg_precision, table_ds.avg_recall, table_ds.avg_hmean, table_ds.correct_num)

# Show the 30 worst tables: image, rebuilt prediction, ground truth.
k = 100
errors = table_ds.top_error_tables(k)
for i, tt in enumerate(errors[:30]):
    print(f'id: {i}, precision: {tt.precision}, recall: {tt.recall}')
    display(tt.display_image())
    display(tt.display_predict_html())
    display(tt.display_gt_html())

# Drill into one table: the 10th worst, or the last one when fewer exist.
# The old `errors[9]` raised IndexError with fewer than 10 tables.
if errors:
    tt = errors[min(9, len(errors) - 1)]
    display(tt.display_image())
    display(tt.display_predict_html())
    display(tt.display_gt_html())
    print(tt.predict_html)
    print(tt.gt_html)
    print(tt.fn)

# Export every table with error_rate >= 0.1 as a hard-case subset.
print(len(table_ds.top_error_tables_by_threshold(0.1)))
table_ds.save_hard_cases_for_dataset(0.1, root_path, anno_fn, Path('hardimgs'))
|