jingze_cheng пре 7 месеци
родитељ
комит
a4d075ac53
5 измењених фајлова са 3 додато и 805 уклоњено
  1. 3 3
      README.md
  2. 0 34
      docs/prepare_data.md
  3. 0 69
      docs/scripts/table_model.sh
  4. 0 583
      docs/scripts/表格结构模型评估.py
  5. 0 116
      docs/train_and_eval.md

+ 3 - 3
README.md

@@ -28,9 +28,9 @@ make all
 
 ## 模型说明
 
-| 类别         | 名称                                                                                                                                                                                      | 配置                       | 训练说明                                           |
-| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------- | -------------------------------------------------- |
-| 表格结构检测 | [ch_ppstructure_mobile_v2.0_SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B) | [./server.py](./server.py) | [表格结构模型训练与评估](./docs/train_and_eval.md) |
+| 类别         | 名称                                                                                                                                                                                      | 配置                       |
+| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------- |
+| 表格结构检测 | [ch_ppstructure_mobile_v2.0_SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B) | [./server.py](./server.py) |
 
 如果更新了模型权重,请同时修改创建镜像时的下载地址:
 

+ 0 - 34
docs/prepare_data.md

@@ -1,34 +0,0 @@
-# 表格数据集准备
-
-表格数据集的图片由版面数据集切图得到,并经过旋转校正预处理。
-
-表格数据集使用 PPOCRLabel 进行标注,标注流程请查看官方文档:[表格标注](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/PPOCRLabel/README_ch.md#22-%E8%A1%A8%E6%A0%BC%E6%A0%87%E6%B3%A8%E8%A7%86%E9%A2%91%E6%BC%94%E7%A4%BA)。
-
-## 数据集格式
-
-数据集为[PaddleOCR 表格识别模型数据集格式](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#11-%E6%95%B0%E6%8D%AE%E9%9B%86%E6%A0%BC%E5%BC%8F),包含表格结构和每个 Cell 的信息:
-
-```text
-{
-   'filename': PMC5755158_010_01.png,                               # 图像名
-   'html': {
-     'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]},     # 表格的HTML字符串
-     'cells': [
-       {
-         'tokens': ['P', 'a', 'd', 'd', 'l', 'e'],                  # 表格中的单个文本
-         'bbox': [x0, y0, x1, y1]                                   # 表格中的单个文本的坐标
-       }
-     ]
-   }
-}
-```
-
-## 下载数据集
-
-[伊利冷饮版面-表格结构数据集](https://huggingface.co/datasets/BethanThornton/table-dataset)
-
-## 调整数据
-
-可使用 [layout-ocr-data-utils](https://gogs.soaringnova.com/yili-ocr/layout-ocr-data-utils) 调整表格数据集,如合并切分,数据增强等,以及进行数据可视化。具体请查看该工具的文档。
-
-可使用 [TableGeneration](https://github.com/WenmuZhou/TableGeneration) 生成表格图像。表格所需的数据量较大,官方推荐至少准备 2000 张用于模型微调。

+ 0 - 69
docs/scripts/table_model.sh

@@ -1,69 +0,0 @@
-#!/bin/bash
-# shellcheck disable=SC2155
-
-set -eux
-
-# SLANet_ch 模型是 PaddleOCR 目前最优的中文表格预训练模型
-# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md
-readonly MODEL_NAME="SLANet_ch"
-readonly MODEL_CONF="configs/table/${MODEL_NAME}.yml"
-# 推理参数
-readonly INFER_IMG_DIR="train_data/table-dataset/conv.v16i/all"
-readonly CUR_TIME=$(date "+%s")
-
-edit_model() {
-    vim "${MODEL_CONF}"
-}
-
-train_model() {
-    python3 tools/train.py -c "${MODEL_CONF}"
-}
-
-train_model_distr() {
-    python3 \
-        -m paddle.distributed.launch \
-        --gpus '0,1,2,3,4,5,6,7' \
-        tools/train.py -c "${MODEL_CONF}"
-}
-
-export_model() {
-    python3 tools/export_model.py \
-        -c "${MODEL_CONF}" \
-        -o Global.pretrained_model="./output/${MODEL_NAME}/best_accuracy" \
-        Global.save_inference_dir="./inference/${MODEL_NAME}"
-}
-
-infer_model() {
-    python3 ppstructure/table/predict_structure.py \
-        --table_model_dir=inference/"${MODEL_NAME}" \
-        --rec_char_dict_path="./ppocr/utils/ppocr_keys_v1.txt" \
-        --table_char_dict_path="./ppocr/utils/dict/table_structure_dict_ch.txt" \
-        --image_dir="${INFER_IMG_DIR}" \
-        --output="inference_results/${MODEL_NAME}_${CUR_TIME}"
-}
-
-main() {
-    case "${1}" in
-    edit)
-        edit_model
-        ;;
-    train)
-        train_model
-        ;;
-    train_distr)
-        train_model_distr
-        ;;
-    export)
-        export_model
-        ;;
-    infer)
-        infer_model
-        ;;
-    *)
-        echo "Invalid option: ${1}"
-        exit 1
-        ;;
-    esac
-}
-
-main "$@"

+ 0 - 583
docs/scripts/表格结构模型评估.py

@@ -1,583 +0,0 @@
-# ---
-# jupyter:
-#   jupytext:
-#     text_representation:
-#       extension: .py
-#       format_name: light
-#       format_version: '1.5'
-#       jupytext_version: 1.16.1
-#   kernelspec:
-#     display_name: Python 3
-#     language: python
-#     name: python3
-# ---
-
-# Put these at the top of every notebook, to get automatic reloading and inline plotting
-# %reload_ext autoreload
-# %autoreload 2
-# %matplotlib inline
-
-import sys
-sys.path.insert(0, './PaddleOCR/')
-
-# +
-from paddleocr import PaddleOCR, PPStructure
-from pathlib import Path
-
-# from IPython.core.display import HTML
-from IPython.display import display, HTML
-from PIL import Image
-import re
-import cv2
-import base64
-from io import BytesIO
-from PIL import Image, ImageOps
-import numpy as np
-from matplotlib import pyplot as plt
-from sx_utils import *
-import json
-import threading
-from collections import namedtuple
-
-
-# +
-def table_res(img_path, ROTATE=-1):
-    im = cv2.imread(img_path)
-    if ROTATE >= 0:
-        im = cv2.rotate(im, ROTATE)
-    html = table_engine(im)[0]['res']['html']
-    return html
-
-def cal_html_to_chs(html):
-    res = []
-    rows = re.split('<tr>', html)
-    for row in rows:
-        row = re.split('<td>', row)
-        cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''),  row))
-        rec_str = ''.join(cells)
-        for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
-            rec_str = rec_str.replace(tag, '')
-            
-        res.append(rec_str)
-    
-    rec_res = ''.join(res).replace(' ','')
-    rec_res = re.split('<tdcolspan="\w+">', rec_res)
-    rec_res = ''.join(rec_res).replace(' ','')
-    print(rec_res)
-    return len(rec_res)
-
-
-# -
-
-lock = threading.Lock()
-def read_annos(path):
-    with open(path, 'r', encoding='utf-8') as f:
-        return f.readlines()
-
-
-from collections import namedtuple
-from typing import List
-import Levenshtein
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from tqdm import tqdm
-import requests
-import shutil
-
-# +
-from decorator import decorator
-@decorator
-def rule1_decorator(f, *args,**kwargs):
-    '''
-    predict_line = ['项目 ', '', '每100克营养素参考值%', '']
-    '''
-    predict_line = args[1]
-    predict_line = f(*args,**kwargs)
-    idx = predict_line.index('')
-    try:
-        if idx == 1:
-            if '项目' in predict_line[0] and '每100克' in predict_line[2]:
-                predict_line[1] = '每100克'
-                r = re.split('每100克', predict_line[2])
-                if len(r) == 2 and r[1]:
-                    predict_line[2] = r[1]
-    except IndexError as e:
-        print(e)
-    return predict_line
-
-@decorator
-def rule2_decorator(f, *args,**kwargs):
-    '''
-    predict_line = ['碳水化合物18.2克', '', '6%', '']
-    '''
-    predict_line = args[1]
-    predict_line = f(*args,**kwargs)
-    idx = predict_line.index('')
-    try:
-        if idx == 1:
-            if '化合物' in predict_line[0]:
-                r = re.split('化合物', predict_line[0])
-                predict_line[0] = '碳水化合物'
-                if len(r) == 2 and r[1]:
-                    predict_line[1] = r[1]
-    except IndexError as e:
-        print(e)
-    return predict_line
-            
-    
-@decorator
-def rule3_decorator(f, *args,**kwargs):
-    '''
-    ['患直质', '1.6克', '3%', '']
-    ['脂扇', '1.1', '19%', '']
-    ['碳水化合物', '勿18.2克', '6%', '']
-    
-    '''
-    predict_line = args[1]
-    predict_line = f(*args,**kwargs)
-    predict_line = [re.sub('患直质', '蛋白质', s) for s in predict_line]
-    predict_line = [re.sub('脂扇', '脂肪', s) for s in predict_line]
-    predict_line = [re.sub('勿(.*克)', '\\1', s) for s in predict_line]
-    predict_line = [re.sub('毫 克', '毫克', s) for s in predict_line]
-    return predict_line
-
-
-# +
-WordCompare = namedtuple('WordCompare', ['gt', 'predict', 'cls'])
-def get_fn(annos, idx):
-    return json.loads(annos[idx])['filename']
-
-def get_gt(annos, idx):
-    return json.loads(annos[idx])['gt']
-
-def get_img_path(annos, idx, root_path):
-    return root_path / get_fn(annos, idx)
-
-def filter_annos(annos: List[str], fns: List[Path]):
-    res = []
-    pattern = '"filename": "(.*.jpg)"'
-    
-    for anno in annos:
-        for fn in fns: 
-            m = re.search(fn.name, anno)
-            if m:
-                sub = f'"filename": "{str(fn)}"'
-                new_anno = re.sub(pattern,sub, anno)
-                res.append(new_anno)
-    return res
-                
-    
-
-def _predict_table2(annos, idx,  root_path):
-    img_path = get_img_path(annos, idx,  root_path=root_path)
-    try:
-        lock.acquire()
-        img_path = str(img_path)
-        html = table_engine(img_path)[0]['res']['html']
-        return html
-    finally:
-        lock.release()
-    return None
-
-def _predict_table1(annos, idx,  root_path):
-    img_path = get_img_path(annos, idx,  root_path=root_path)
-    img_path = str(img_path)
-    img = cv2.imread(img_path)
-    img_str = cv2bgr_base64(img)
-    payload = {'image': img_str, 'det': 'conv'}
-    r = requests.post('http://ocr-table.yili-ocr:8080/ocr_system/table', json=payload)
-    res = r.json()
-    if 'status' in res and res['status'] == '000':
-        return res['result']['html']
-    else:
-        return None
-    
-def _predict_table3(annos, idx,  root_path):
-    img_path = get_img_path(annos, idx,  root_path=root_path)
-    img_path = str(img_path)
-    img = cv2.imread(img_path)
-    img_str = cv2bgr_base64(img)
-    payload = {'image': img_str, 'det': 'yes', 'prefer_cell': True}
-    r = requests.post('http://ocr-table.sxkj.com/ocr_system/table', headers={"Content-type": "application/json"}, json=payload)
-    res = r.json()
-    if 'status' in res and res['status'] == '000':
-        return res['result']['html']
-    else:
-        return None
-
-def predict_table(annos, idx, root_path):
-    max_retry = 3
-    def predict_with_retry(*args, retry=0):
-        if retry > max_retry:
-            raise RuntimeError("Max retry failed!")
-        if retry > 0:
-            print(f"retry = {retry}, idx = {idx}")
-        try:
-            return _predict_table3(*args)
-        except Exception as e:
-            print(f'request error: {e}')
-            return predict_with_retry(*args, retry=retry + 1)
-    return predict_with_retry(annos, idx, root_path)
-    
-class PairLine:
-    def __init__(self, gt_line, predict_line):
-        '''
-        gt_line: ['项目', '每100克', '营养素参考值%', '']
-        predict_line: ['项目', '', '每100克营养素参考值%', '']
-        '''
-        self.gt_line = gt_line
-        self.predict_line = predict_line
-        self.result = []
-        self.cols = 3
-
-    def compare(self):
-        for i, word in enumerate(self.predict_line):
-            # print(self.predict_line, self.gt_line)
-            if not self.gt_line:
-                for j in range(self.cols):
-                    self.result.append(WordCompare(gt='', predict=word, cls=False))
-            try:
-                cls = True if word.strip() == self.gt_line[i].strip() else False
-                if not word and not self.gt_line[i]:
-                    continue
-                self.result.append(WordCompare(gt=self.gt_line[i], predict=word, cls=cls))
-            except IndexError:
-                self.result.append(WordCompare(gt='', predict=word, cls=False))
-        for i, word in enumerate(self.gt_line):
-            if not self.predict_line:
-                for j in range(self.cols):
-                    self.result.append(WordCompare(gt=word, predict='', cls=False))
-            if i >= len(self.predict_line):
-                self.result.append(WordCompare(gt=word, predict='', cls=False))
-
-
-    def __repr__(self):
-        return f'gt_line: {self.gt_line}, predict_line: {self.predict_line}'
-
-
-# +
-class Table:
-    def __init__(self, fn, gt_html, predict_html):
-        self.fn = fn
-        self.gt_html = gt_html
-        self.predict_html = predict_html
-        self.format_lines = []
-        self.pair_lines = []
-        self.result = self.get_result()
-
-    
-    @classmethod
-    def from_dict(cls, tt):
-        gt_html = tt['gt_html']
-        predict_html = tt['predict_html']
-        fn = Path(tt['fn'])
-        t = cls(fn, gt_html, predict_html)
-        return t
-        
-    
-    def to_dict(self):
-        return {
-            'fn': str(self.fn),
-            'gt_html': self.gt_html,
-            'predict_html': self.predict_html,
-            'pair_lines': [{'gt_line': o.gt_line, 'predict_line': o.predict_line} for o in self.pair_lines],
-            'result': [[{'gt': o.gt, 'predict': o.predict, 'cls': o.cls} for  o in line] for line in self.result]
-        }
-    
-    def display_image(self):
-        im = Image.open(self.fn)
-        return ImageOps.exif_transpose(im)
-
-    def display_predict_html(self):
-        return HTML(self.format_predict_html)
-
-    def display_gt_html(self):
-        return HTML(self.gt_html)
-
-    @property
-    def format_predict_html(self):
-        if self.format_lines:
-            header = '<html><body><table><tbody>'
-            footer = '</tbody></table></body></html>'
-            COLS = 3
-            html = []
-            for i, line in enumerate(self.format_lines):
-                html.append('<tr>')
-                for j in range(COLS):
-                    try:
-                        if i == 0 and '成分表' in line[j]:
-                            html.append('<td colspan="3">')
-                            html.append(line[j])
-                            html.append('</td>')
-                            break;
-                        else:
-                            html.append('<td>')
-                            html.append(line[j])
-                            html.append('</td>')
-                    except IndexError as e:
-                        print('format_predict_html', e)
-                        html.append('<td>')
-                        html.append('')
-                        html.append('</td>')
-                        continue
-                html.append('</tr>')
-            res = f'{header}{"".join(html)}{footer}'
-            return res
-        else:
-            return self.predict_html
-        
-    
-    
-    @property
-    def error_rate(self):
-        corrects = 0
-        errors = 0
-        for line in self.result:
-            for word in line:
-                if word.cls:
-                    corrects += 1
-                else:
-                    errors += 1
-        
-        total = corrects + errors
-        return 0 if errors == 1 else errors / total
-    
-    @property
-    def precision(self):
-        corrects = 0
-        p_len = 0
-        for line in self.result:
-            for word in line:
-                if word.cls:
-                    corrects += 1
-                if word.predict:
-                    p_len += 1
-        return 0 if p_len == 0 else corrects / p_len
-    
-    @property
-    def recall(self):
-        corrects = 0
-        g_len = 0
-        for line in self.result:
-            for word in line:
-                if word.cls:
-                    corrects += 1
-                if word.gt:
-                    g_len += 1
-        return 0 if g_len == 0 else corrects / g_len
-    
-    @property
-    def hmean(self):
-        total = self.recall + self.precision
-        return 0 if total == 0 else 2 * self.precision * self.recall / total
-
-    
-    def get_result(self):
-        res = []
-        self._generate_pair_lines()
-        # print(self.pair_lines)
-        for pair_line in self.pair_lines:
-            pair_line.compare()
-            res.append(pair_line.result)
-        return res
-
-    
-    @rule3_decorator
-    @rule2_decorator
-    @rule1_decorator
-    def _format_predict_line(self, predict_line):
-        return predict_line
-
-    def _get_lines(self, html) -> List[str]:
-        '''
-        res:  ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
-        '''
-        if not html:
-            return []
-        rows = re.split('<tr>', html)
-        res = []
-        for row in rows:
-            m = re.findall('<td.*>.*</td>', row)
-            if m:
-                res.extend(m)
-        return res
-    
-
-    def _generate_pair_lines(self):
-        gt_lines = self._get_lines(self.gt_html)
-        predict_lines = self._get_lines(self.predict_html)
-        gt_words_list = [self._split_to_words(line) for line in gt_lines]
-        predict_words_list = [self._format_predict_line(self._split_to_words(line)) for line in predict_lines]
-        self.format_lines.extend(predict_words_list)
-        
-        DistEntry = namedtuple('DistEntry', ['i', 'j', 'dist'])
-        dist_entries = []
-        p = [False] * len(predict_words_list)
-        g = [False] * len(gt_words_list)
-        for i, p_line in enumerate(predict_words_list):
-            for j, g_line in enumerate(gt_words_list):
-                dist = Levenshtein.distance(''.join(p_line), ''.join(g_line))
-                dist_entries.append(DistEntry(i=i, j=j, dist=dist))
-        dist_entries.sort(key=lambda e: e.dist)
-        for e in dist_entries:
-            if not p[e.i] and not g[e.j]:
-                p[e.i] = True
-                g[e.j] = True
-                self.pair_lines.append(PairLine(predict_line=predict_words_list[e.i], gt_line=gt_words_list[e.j]))
-        for i in range(len(p)):
-            if not p[i]:
-                self.pair_lines.append(PairLine(predict_line=predict_words_list[i], gt_line=[]))
-        for i in range(len(g)):
-            if not g[i]:
-                self.pair_lines.append(PairLine(predict_line=[], gt_line=gt_words_list[i]))
-
-                
-    def _match_gt_line(self, line, gt_lines):
-        line = ''.join(line)
-        min_dist = 9999
-        res = []
-        for i, gt_line in enumerate(gt_lines):
-            gt_line = ''.join(gt_line)
-            dist = Levenshtein.distance(gt_line, line)
-            if dist < min_dist:
-                min_dist = dist
-                res = gt_lines[i]
-        return res
-
-
-    def _split_to_words(self, line):
-        '''
-        line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'
-        res: ['项目', '每100克', '营养素参考值%', '']
-        '''
-        res = [re.sub('<td.*>', '', word) for word in re.split('</td>', line)]
-        return res
-
-def generate_tables(annos, root_path, i):
-    predict_html = predict_table(annos, i, root_path=root_path)
-    gt_html = get_gt(annos, i)
-    fn = get_img_path(annos, i, root_path=root_path)
-    table = Table(fn, gt_html=gt_html, predict_html=predict_html)
-    return table
-
-class TableDataset:
-    def __init__(self, root_path=None, anno_fn=None):
-        if root_path and anno_fn:
-            self.tables = []
-            annos = read_annos(anno_fn)
-            l = len(annos)
-            # l = 10
-            with ThreadPoolExecutor(max_workers=10) as executor:
-                tables = list(tqdm(executor.map(partial(generate_tables, annos, root_path), range(l)), total=l))
-            for table in tables:
-                self.tables.append(table)
-        else:
-            self.tables = []
-    @property
-    def correct_num(self):
-        return len(list(filter(lambda x: x.error_rate == 0., self.tables)))
-
-    @property
-    def avg_error_rate(self):
-        return np.mean([o.error_rate for o in self.tables])
-    
-    @property
-    def avg_precision(self):
-        return np.mean([o.precision for o in self.tables])
-    
-    @property
-    def avg_recall(self):
-        return np.mean([o.recall for o in self.tables])
-    
-    @property
-    def avg_hmean(self):
-        return np.mean([o.hmean for o in self.tables])
-
-    def save_hard_cases_for_dataset(self, th, root_path, anno_fn, dst_path):
-        tables = self.top_error_tables_by_threshold(th)
-        fns = [t.fn for t in tables]
-        if not (root_path / dst_path).exists():
-            (root_path / dst_path).mkdir()
-        for fn in tqdm(fns):
-            src = fn
-            dst = root_path / dst_path / fn.name
-            shutil.copy2(src, dst)
-        fns = [dst_path / t.fn.name for t in tables]
-        annos = read_annos(anno_fn)
-        annos = filter_annos(annos, fns)
-        write_annos(annos, root_path / dst_path / 'gt.txt')
-
-    def top_error_tables_by_threshold(self, th):
-        res = []
-        for r in self.tables:
-            if r.error_rate >= th:
-                res.append(r)
-        return res
-        
-
-    def top_error_tables(self, k):
-        tables = sorted(self.tables, key=lambda x: x.error_rate, reverse=True)
-        return tables[:k]
-
-    def to_json(self, fn):
-        res = [o.to_dict() for o in self.tables]
-        with open(fn, 'w', encoding='utf-8') as f:
-            json.dump(res, f)
-
-    @classmethod
-    def from_json(cls, fn):
-        res = cls()
-        with open(fn, 'r') as f:
-            ts = json.load(f)
-        for o in ts:
-            res.tables.append(Table.from_dict(o))
-        return res
-
-# -
-
-root_path = Path('table-dataset')
-anno_fn = root_path / 'merge.txt'
-# anno_fn = root_path / 'hardimgs/gt.txt'
-
-# + jupyter={"outputs_hidden": true}
-import time
-cur_time = int(time.time())
-table_ds = TableDataset(root_path, anno_fn)
-json_filename = f'table_merge_{cur_time}.json'
-table_ds.to_json(json_filename)
-print(json_filename)
-
-# +
-# table_ds = TableDataset.from_json('table_unconv_request_post_1708499925.json')
-# -
-
-table_ds.avg_precision, table_ds.avg_recall, table_ds.avg_hmean, table_ds.correct_num
-
-k = 100
-errors = table_ds.top_error_tables(k)
-
-tt = errors[9]
-
-for i, tt in enumerate(errors[:30]):
-    print(f'id: {i}, precision: {tt.precision}, recall: {tt.recall}')
-    display(tt.display_image())
-    display(tt.display_predict_html())
-    display(tt.display_gt_html())
-
-tt.display_image()
-
-tt.display_predict_html()
-
-tt.display_gt_html()
-
-tt.predict_html
-
-tt.gt_html
-
-tt.fn
-
-len(table_ds.top_error_tables_by_threshold(0.1))
-
-table_ds.save_hard_cases_for_dataset(0.1, root_path, anno_fn, Path('hardimgs'))
-
-

+ 0 - 116
docs/train_and_eval.md

@@ -1,116 +0,0 @@
-# 表格结构模型训练与评估
-
-[PaddleOCR 的表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/table/README_ch.md)流程中包含三个模型:表格结构预测模型,单行文本检测模型,单行文本识别模型。我们目前对表格结构预测模型进行了训练。
-
-## 准备数据集
-
-请参考:[数据集准备](./prepare_data.md)
-
-## 准备环境
-
-克隆 PaddleOCR 仓库,进入仓库目录,安装依赖:
-
-```bash
-git clone --depth 1 https://github.com/PaddlePaddle/PaddleOCR.git
-cd PaddleOCR
-pip install -r requirements.txt
-```
-
-PaddleOCR 训练数据的默认存储目录是 `PaddleOCR/train_data`,我们将数据集下载到本地后,可以拷贝数据集或创建软链接到该目录:
-
-```bash
-cp -r /path/to/table-dataset ./train_data/table-dataset
-# 或者
-ln -sf /path/to/table-dataset ./train_data/table-dataset
-```
-
-请将我们的训练脚本 [table_model.sh](./scripts/table_model.sh) 拷贝至 `PaddleOCR` 目录:
-
-```bash
-cp /path/to/table_model.sh ./table_model.sh
-```
-
-PaddleOCR 对训练过程做了模块化,如果要训练不同的模型,我们只需要在脚本开头更换配置文件。
-
-## 表格结构预测模型训练与评估
-
-### 训练
-
-以我们目前使用的 SLANet 模型为例(官方文档:[表格识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B),配置文件:[SLANet_ch.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/configs/table/SLANet_ch.yml)),修改配置文件如下:
-
-```bash
-$ cat configs/table/SLANet_ch.yml
-Global:
-  use_gpu: True
-  # 训练轮数
-  epoch_num: 400
-  # 预训练模型文件
-  pretrained_model: ./pretrain_models/ch_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
-...
-
-Optimizer:
-  name: Adam
-  beta1: 0.9
-  beta2: 0.999
-  clip_norm: 5.0
-  lr:
-    # 学习率
-    learning_rate: 0.001
-...
-
-Train:
-  dataset:
-    name: PubTabDataSet
-    # 训练集目录
-    data_dir: train_data/table-dataset/artificial
-    # 训练集标注文件
-    label_file_list: [train_data/table-dataset/artificial/train.txt]
-...
-
-Eval:
-  dataset:
-    name: PubTabDataSet
-    # 验证集目录
-    data_dir: train_data/table-dataset/artificial/
-    # 验证集标注文件
-    label_file_list: [train_data/table-dataset/artificial/test.txt]
-...
-```
-
-其中,学习率 `learning_rate` (记为`lr`) 需要按运行时 `GPU卡数` (记为`GPU_number`) 和 `batch_size_per_card` (记为`batch_size`) 进行调整,公式为:
-
-**lr<sub>new</sub> = lr<sub>default</sub> \* (batch_size<sub>new</sub> \* GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> \* GPU_number<sub>default</sub>)**
-
-PaddleOCR 默认的配置文件对应 **batch_size<sub>default</sub>=8**,**GPU_number<sub>default</sub>=8**。
-
-更详细的参数调整说明,请参考官方文档:[模型微调](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/finetune.md)。更详细的配置项含义,请参考官方文档:[配置文件内容与生成](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/config.md)。
-
-训练模型:
-
-```bash
-# 单卡训练
-./table_model.sh train
-# 多卡训练
-./table_model.sh train_distr
-```
-
-导出模型:
-
-```bash
-./table_model.sh export
-```
-
-使用导出的模型推理:
-
-```bash
-./table_model.sh infer
-```
-
-更详细的模型训练,推理,部署说明请参考:[官方文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md)
-
-### 评估
-
-评估脚本:[表格结构模型评估.py](./scripts/表格结构模型评估.py)<br>
-可使用 [jupytext](https://github.com/mwouts/jupytext) 将其转换为 notebook 文件:`jupytext --to ipynb notebook.py`
-
-更多请参考官方文档的[评估方法](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#3-%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0%E4%B8%8E%E9%A2%84%E6%B5%8B)。