Browse Source

docs: 添加评估脚本

jingze_cheng 7 months ago
parent
commit
b17fdbdb5d
2 changed files with 587 additions and 1 deletions
  1. 583 0
      docs/scripts/表格结构模型评估.py
  2. 4 1
      docs/train_and_eval.md

+ 583 - 0
docs/scripts/表格结构模型评估.py

@@ -0,0 +1,583 @@
+# ---
+# jupyter:
+#   jupytext:
+#     text_representation:
+#       extension: .py
+#       format_name: light
+#       format_version: '1.5'
+#       jupytext_version: 1.16.1
+#   kernelspec:
+#     display_name: Python 3
+#     language: python
+#     name: python3
+# ---
+
+# Put these at the top of every notebook, to get automatic reloading and inline plotting
+# %reload_ext autoreload
+# %autoreload 2
+# %matplotlib inline
+
+import sys
+sys.path.insert(0, './PaddleOCR/')
+
+# +
+from paddleocr import PaddleOCR, PPStructure
+from pathlib import Path
+
+# from IPython.core.display import HTML
+from IPython.display import display, HTML
+from PIL import Image
+import re
+import cv2
+import base64
+from io import BytesIO
+from PIL import Image, ImageOps
+import numpy as np
+from matplotlib import pyplot as plt
+from sx_utils import *
+import json
+import threading
+from collections import namedtuple
+
+
+# +
+def table_res(img_path, ROTATE=-1):
+    im = cv2.imread(img_path)
+    if ROTATE >= 0:
+        im = cv2.rotate(im, ROTATE)
+    html = table_engine(im)[0]['res']['html']
+    return html
+
+def cal_html_to_chs(html):
+    res = []
+    rows = re.split('<tr>', html)
+    for row in rows:
+        row = re.split('<td>', row)
+        cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''),  row))
+        rec_str = ''.join(cells)
+        for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
+            rec_str = rec_str.replace(tag, '')
+            
+        res.append(rec_str)
+    
+    rec_res = ''.join(res).replace(' ','')
+    rec_res = re.split('<tdcolspan="\w+">', rec_res)
+    rec_res = ''.join(rec_res).replace(' ','')
+    print(rec_res)
+    return len(rec_res)
+
+
+# -
+
+lock = threading.Lock()
+def read_annos(path):
+    with open(path, 'r', encoding='utf-8') as f:
+        return f.readlines()
+
+
+from collections import namedtuple
+from typing import List
+import Levenshtein
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from tqdm import tqdm
+import requests
+import shutil
+
+# +
+from decorator import decorator
+@decorator
+def rule1_decorator(f, *args,**kwargs):
+    '''
+    predict_line = ['项目 ', '', '每100克营养素参考值%', '']
+    '''
+    predict_line = args[1]
+    predict_line = f(*args,**kwargs)
+    idx = predict_line.index('')
+    try:
+        if idx == 1:
+            if '项目' in predict_line[0] and '每100克' in predict_line[2]:
+                predict_line[1] = '每100克'
+                r = re.split('每100克', predict_line[2])
+                if len(r) == 2 and r[1]:
+                    predict_line[2] = r[1]
+    except IndexError as e:
+        print(e)
+    return predict_line
+
+@decorator
+def rule2_decorator(f, *args,**kwargs):
+    '''
+    predict_line = ['碳水化合物18.2克', '', '6%', '']
+    '''
+    predict_line = args[1]
+    predict_line = f(*args,**kwargs)
+    idx = predict_line.index('')
+    try:
+        if idx == 1:
+            if '化合物' in predict_line[0]:
+                r = re.split('化合物', predict_line[0])
+                predict_line[0] = '碳水化合物'
+                if len(r) == 2 and r[1]:
+                    predict_line[1] = r[1]
+    except IndexError as e:
+        print(e)
+    return predict_line
+            
+    
+@decorator
+def rule3_decorator(f, *args,**kwargs):
+    '''
+    ['患直质', '1.6克', '3%', '']
+    ['脂扇', '1.1', '19%', '']
+    ['碳水化合物', '勿18.2克', '6%', '']
+    
+    '''
+    predict_line = args[1]
+    predict_line = f(*args,**kwargs)
+    predict_line = [re.sub('患直质', '蛋白质', s) for s in predict_line]
+    predict_line = [re.sub('脂扇', '脂肪', s) for s in predict_line]
+    predict_line = [re.sub('勿(.*克)', '\\1', s) for s in predict_line]
+    predict_line = [re.sub('毫 克', '毫克', s) for s in predict_line]
+    return predict_line
+
+
+# +
+WordCompare = namedtuple('WordCompare', ['gt', 'predict', 'cls'])
+def get_fn(annos, idx):
+    return json.loads(annos[idx])['filename']
+
+def get_gt(annos, idx):
+    return json.loads(annos[idx])['gt']
+
+def get_img_path(annos, idx, root_path):
+    return root_path / get_fn(annos, idx)
+
+def filter_annos(annos: List[str], fns: List[Path]):
+    res = []
+    pattern = '"filename": "(.*.jpg)"'
+    
+    for anno in annos:
+        for fn in fns: 
+            m = re.search(fn.name, anno)
+            if m:
+                sub = f'"filename": "{str(fn)}"'
+                new_anno = re.sub(pattern,sub, anno)
+                res.append(new_anno)
+    return res
+                
+    
+
+def _predict_table2(annos, idx,  root_path):
+    img_path = get_img_path(annos, idx,  root_path=root_path)
+    try:
+        lock.acquire()
+        img_path = str(img_path)
+        html = table_engine(img_path)[0]['res']['html']
+        return html
+    finally:
+        lock.release()
+    return None
+
+def _predict_table1(annos, idx,  root_path):
+    img_path = get_img_path(annos, idx,  root_path=root_path)
+    img_path = str(img_path)
+    img = cv2.imread(img_path)
+    img_str = cv2bgr_base64(img)
+    payload = {'image': img_str, 'det': 'conv'}
+    r = requests.post('http://ocr-table.yili-ocr:8080/ocr_system/table', json=payload)
+    res = r.json()
+    if 'status' in res and res['status'] == '000':
+        return res['result']['html']
+    else:
+        return None
+    
+def _predict_table3(annos, idx,  root_path):
+    img_path = get_img_path(annos, idx,  root_path=root_path)
+    img_path = str(img_path)
+    img = cv2.imread(img_path)
+    img_str = cv2bgr_base64(img)
+    payload = {'image': img_str, 'det': 'yes', 'prefer_cell': True}
+    r = requests.post('http://ocr-table.sxkj.com/ocr_system/table', headers={"Content-type": "application/json"}, json=payload)
+    res = r.json()
+    if 'status' in res and res['status'] == '000':
+        return res['result']['html']
+    else:
+        return None
+
+def predict_table(annos, idx, root_path):
+    max_retry = 3
+    def predict_with_retry(*args, retry=0):
+        if retry > max_retry:
+            raise RuntimeError("Max retry failed!")
+        if retry > 0:
+            print(f"retry = {retry}, idx = {idx}")
+        try:
+            return _predict_table3(*args)
+        except Exception as e:
+            print(f'request error: {e}')
+            return predict_with_retry(*args, retry=retry + 1)
+    return predict_with_retry(annos, idx, root_path)
+    
+class PairLine:
+    def __init__(self, gt_line, predict_line):
+        '''
+        gt_line: ['项目', '每100克', '营养素参考值%', '']
+        predict_line: ['项目', '', '每100克营养素参考值%', '']
+        '''
+        self.gt_line = gt_line
+        self.predict_line = predict_line
+        self.result = []
+        self.cols = 3
+
+    def compare(self):
+        for i, word in enumerate(self.predict_line):
+            # print(self.predict_line, self.gt_line)
+            if not self.gt_line:
+                for j in range(self.cols):
+                    self.result.append(WordCompare(gt='', predict=word, cls=False))
+            try:
+                cls = True if word.strip() == self.gt_line[i].strip() else False
+                if not word and not self.gt_line[i]:
+                    continue
+                self.result.append(WordCompare(gt=self.gt_line[i], predict=word, cls=cls))
+            except IndexError:
+                self.result.append(WordCompare(gt='', predict=word, cls=False))
+        for i, word in enumerate(self.gt_line):
+            if not self.predict_line:
+                for j in range(self.cols):
+                    self.result.append(WordCompare(gt=word, predict='', cls=False))
+            if i >= len(self.predict_line):
+                self.result.append(WordCompare(gt=word, predict='', cls=False))
+
+
+    def __repr__(self):
+        return f'gt_line: {self.gt_line}, predict_line: {self.predict_line}'
+
+
+# +
+class Table:
+    def __init__(self, fn, gt_html, predict_html):
+        self.fn = fn
+        self.gt_html = gt_html
+        self.predict_html = predict_html
+        self.format_lines = []
+        self.pair_lines = []
+        self.result = self.get_result()
+
+    
+    @classmethod
+    def from_dict(cls, tt):
+        gt_html = tt['gt_html']
+        predict_html = tt['predict_html']
+        fn = Path(tt['fn'])
+        t = cls(fn, gt_html, predict_html)
+        return t
+        
+    
+    def to_dict(self):
+        return {
+            'fn': str(self.fn),
+            'gt_html': self.gt_html,
+            'predict_html': self.predict_html,
+            'pair_lines': [{'gt_line': o.gt_line, 'predict_line': o.predict_line} for o in self.pair_lines],
+            'result': [[{'gt': o.gt, 'predict': o.predict, 'cls': o.cls} for  o in line] for line in self.result]
+        }
+    
+    def display_image(self):
+        im = Image.open(self.fn)
+        return ImageOps.exif_transpose(im)
+
+    def display_predict_html(self):
+        return HTML(self.format_predict_html)
+
+    def display_gt_html(self):
+        return HTML(self.gt_html)
+
+    @property
+    def format_predict_html(self):
+        if self.format_lines:
+            header = '<html><body><table><tbody>'
+            footer = '</tbody></table></body></html>'
+            COLS = 3
+            html = []
+            for i, line in enumerate(self.format_lines):
+                html.append('<tr>')
+                for j in range(COLS):
+                    try:
+                        if i == 0 and '成分表' in line[j]:
+                            html.append('<td colspan="3">')
+                            html.append(line[j])
+                            html.append('</td>')
+                            break;
+                        else:
+                            html.append('<td>')
+                            html.append(line[j])
+                            html.append('</td>')
+                    except IndexError as e:
+                        print('format_predict_html', e)
+                        html.append('<td>')
+                        html.append('')
+                        html.append('</td>')
+                        continue
+                html.append('</tr>')
+            res = f'{header}{"".join(html)}{footer}'
+            return res
+        else:
+            return self.predict_html
+        
+    
+    
+    @property
+    def error_rate(self):
+        corrects = 0
+        errors = 0
+        for line in self.result:
+            for word in line:
+                if word.cls:
+                    corrects += 1
+                else:
+                    errors += 1
+        
+        total = corrects + errors
+        return 0 if errors == 1 else errors / total
+    
+    @property
+    def precision(self):
+        corrects = 0
+        p_len = 0
+        for line in self.result:
+            for word in line:
+                if word.cls:
+                    corrects += 1
+                if word.predict:
+                    p_len += 1
+        return 0 if p_len == 0 else corrects / p_len
+    
+    @property
+    def recall(self):
+        corrects = 0
+        g_len = 0
+        for line in self.result:
+            for word in line:
+                if word.cls:
+                    corrects += 1
+                if word.gt:
+                    g_len += 1
+        return 0 if g_len == 0 else corrects / g_len
+    
+    @property
+    def hmean(self):
+        total = self.recall + self.precision
+        return 0 if total == 0 else 2 * self.precision * self.recall / total
+
+    
+    def get_result(self):
+        res = []
+        self._generate_pair_lines()
+        # print(self.pair_lines)
+        for pair_line in self.pair_lines:
+            pair_line.compare()
+            res.append(pair_line.result)
+        return res
+
+    
+    @rule3_decorator
+    @rule2_decorator
+    @rule1_decorator
+    def _format_predict_line(self, predict_line):
+        return predict_line
+
+    def _get_lines(self, html) -> List[str]:
+        '''
+        res:  ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
+        '''
+        if not html:
+            return []
+        rows = re.split('<tr>', html)
+        res = []
+        for row in rows:
+            m = re.findall('<td.*>.*</td>', row)
+            if m:
+                res.extend(m)
+        return res
+    
+
+    def _generate_pair_lines(self):
+        gt_lines = self._get_lines(self.gt_html)
+        predict_lines = self._get_lines(self.predict_html)
+        gt_words_list = [self._split_to_words(line) for line in gt_lines]
+        predict_words_list = [self._format_predict_line(self._split_to_words(line)) for line in predict_lines]
+        self.format_lines.extend(predict_words_list)
+        
+        DistEntry = namedtuple('DistEntry', ['i', 'j', 'dist'])
+        dist_entries = []
+        p = [False] * len(predict_words_list)
+        g = [False] * len(gt_words_list)
+        for i, p_line in enumerate(predict_words_list):
+            for j, g_line in enumerate(gt_words_list):
+                dist = Levenshtein.distance(''.join(p_line), ''.join(g_line))
+                dist_entries.append(DistEntry(i=i, j=j, dist=dist))
+        dist_entries.sort(key=lambda e: e.dist)
+        for e in dist_entries:
+            if not p[e.i] and not g[e.j]:
+                p[e.i] = True
+                g[e.j] = True
+                self.pair_lines.append(PairLine(predict_line=predict_words_list[e.i], gt_line=gt_words_list[e.j]))
+        for i in range(len(p)):
+            if not p[i]:
+                self.pair_lines.append(PairLine(predict_line=predict_words_list[i], gt_line=[]))
+        for i in range(len(g)):
+            if not g[i]:
+                self.pair_lines.append(PairLine(predict_line=[], gt_line=gt_words_list[i]))
+
+                
+    def _match_gt_line(self, line, gt_lines):
+        line = ''.join(line)
+        min_dist = 9999
+        res = []
+        for i, gt_line in enumerate(gt_lines):
+            gt_line = ''.join(gt_line)
+            dist = Levenshtein.distance(gt_line, line)
+            if dist < min_dist:
+                min_dist = dist
+                res = gt_lines[i]
+        return res
+
+
+    def _split_to_words(self, line):
+        '''
+        line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'
+        res: ['项目', '每100克', '营养素参考值%', '']
+        '''
+        res = [re.sub('<td.*>', '', word) for word in re.split('</td>', line)]
+        return res
+
+def generate_tables(annos, root_path, i):
+    predict_html = predict_table(annos, i, root_path=root_path)
+    gt_html = get_gt(annos, i)
+    fn = get_img_path(annos, i, root_path=root_path)
+    table = Table(fn, gt_html=gt_html, predict_html=predict_html)
+    return table
+
+class TableDataset:
+    def __init__(self, root_path=None, anno_fn=None):
+        if root_path and anno_fn:
+            self.tables = []
+            annos = read_annos(anno_fn)
+            l = len(annos)
+            # l = 10
+            with ThreadPoolExecutor(max_workers=10) as executor:
+                tables = list(tqdm(executor.map(partial(generate_tables, annos, root_path), range(l)), total=l))
+            for table in tables:
+                self.tables.append(table)
+        else:
+            self.tables = []
+    @property
+    def correct_num(self):
+        return len(list(filter(lambda x: x.error_rate == 0., self.tables)))
+
+    @property
+    def avg_error_rate(self):
+        return np.mean([o.error_rate for o in self.tables])
+    
+    @property
+    def avg_precision(self):
+        return np.mean([o.precision for o in self.tables])
+    
+    @property
+    def avg_recall(self):
+        return np.mean([o.recall for o in self.tables])
+    
+    @property
+    def avg_hmean(self):
+        return np.mean([o.hmean for o in self.tables])
+
+    def save_hard_cases_for_dataset(self, th, root_path, anno_fn, dst_path):
+        tables = self.top_error_tables_by_threshold(th)
+        fns = [t.fn for t in tables]
+        if not (root_path / dst_path).exists():
+            (root_path / dst_path).mkdir()
+        for fn in tqdm(fns):
+            src = fn
+            dst = root_path / dst_path / fn.name
+            shutil.copy2(src, dst)
+        fns = [dst_path / t.fn.name for t in tables]
+        annos = read_annos(anno_fn)
+        annos = filter_annos(annos, fns)
+        write_annos(annos, root_path / dst_path / 'gt.txt')
+
+    def top_error_tables_by_threshold(self, th):
+        res = []
+        for r in self.tables:
+            if r.error_rate >= th:
+                res.append(r)
+        return res
+        
+
+    def top_error_tables(self, k):
+        tables = sorted(self.tables, key=lambda x: x.error_rate, reverse=True)
+        return tables[:k]
+
+    def to_json(self, fn):
+        res = [o.to_dict() for o in self.tables]
+        with open(fn, 'w', encoding='utf-8') as f:
+            json.dump(res, f)
+
+    @classmethod
+    def from_json(cls, fn):
+        res = cls()
+        with open(fn, 'r') as f:
+            ts = json.load(f)
+        for o in ts:
+            res.tables.append(Table.from_dict(o))
+        return res
+
+# -
+
+root_path = Path('table-dataset')
+anno_fn = root_path / 'merge.txt'
+# anno_fn = root_path / 'hardimgs/gt.txt'
+
+# + jupyter={"outputs_hidden": true}
+import time
+cur_time = int(time.time())
+table_ds = TableDataset(root_path, anno_fn)
+json_filename = f'table_merge_{cur_time}.json'
+table_ds.to_json(json_filename)
+print(json_filename)
+
+# +
+# table_ds = TableDataset.from_json('table_unconv_request_post_1708499925.json')
+# -
+
+table_ds.avg_precision, table_ds.avg_recall, table_ds.avg_hmean, table_ds.correct_num
+
+k = 100
+errors = table_ds.top_error_tables(k)
+
+tt = errors[9]
+
+for i, tt in enumerate(errors[:30]):
+    print(f'id: {i}, precision: {tt.precision}, recall: {tt.recall}')
+    display(tt.display_image())
+    display(tt.display_predict_html())
+    display(tt.display_gt_html())
+
+tt.display_image()
+
+tt.display_predict_html()
+
+tt.display_gt_html()
+
+tt.predict_html
+
+tt.gt_html
+
+tt.fn
+
+len(table_ds.top_error_tables_by_threshold(0.1))
+
+table_ds.save_hard_cases_for_dataset(0.1, root_path, anno_fn, Path('hardimgs'))
+
+

+ 4 - 1
docs/train_and_eval.md

@@ -110,4 +110,7 @@ PaddleOCR 默认的配置文件对应 **batch_size<sub>default</sub>=8**,**GPU
 
 ### 评估
 
-请参考官方文档的[评估方法](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#3-%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0%E4%B8%8E%E9%A2%84%E6%B5%8B)。
+评估脚本:[表格结构模型评估.py](./scripts/表格结构模型评估.py)<br>
+可使用 [jupytext](https://github.com/mwouts/jupytext) 将其转换为 notebook 文件:`jupytext --to ipynb notebook.py`
+
+更多请参考官方文档的[评估方法](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#3-%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0%E4%B8%8E%E9%A2%84%E6%B5%8B)。