Browse Source

docs: 添加训练文档

jingze_cheng 7 months ago
parent
commit
8cc64989d5
4 changed files with 230 additions and 4 deletions
  1. 4 4
      README.md
  2. 49 0
      docs/prepare_data.md
  3. 69 0
      docs/scripts/table_model.sh
  4. 108 0
      docs/train_and_eval.md

+ 4 - 4
README.md

@@ -26,11 +26,11 @@ python -m unittest discover testing '*_test.py' -v
 make all
 ```
 
-## 模型配置
+## 模型说明
 
-| 模型类别     | 模型名称                                                                                                                                                                                  | 模型配置                   |
-| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------- |
-| 表格结构检测 | [ch_ppstructure_mobile_v2.0_SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B) | [./server.py](./server.py) |
+| 类别         | 名称                                                                                                                                                                                      | 配置                       | 训练说明                                           |
+| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------- | -------------------------------------------------- |
+| 表格结构检测 | [ch_ppstructure_mobile_v2.0_SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B) | [./server.py](./server.py) | [表格结构模型训练与评估](./docs/train_and_eval.md) |
 
 如果更新了模型权重,请同时修改创建镜像时的下载地址:
 

+ 49 - 0
docs/prepare_data.md

@@ -0,0 +1,49 @@
+# 表格数据集准备
+
+表格数据集的图片由版面数据集切图得到,并经过页面旋转(Page Rotation)和倾斜校正(Skew Correction)预处理。
+
+表格数据集使用 PPOCRLabel 进行标注,标注流程请查看官方文档:[表格标注](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/PPOCRLabel/README_ch.md#22-%E8%A1%A8%E6%A0%BC%E6%A0%87%E6%B3%A8%E8%A7%86%E9%A2%91%E6%BC%94%E7%A4%BA)。
+
+## 数据集格式
+
+数据集为[PaddleOCR 表格识别模型数据集格式](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#11-%E6%95%B0%E6%8D%AE%E9%9B%86%E6%A0%BC%E5%BC%8F),包含表格结构和每个 Cell 的信息:
+
+```text
+{
+   'filename': PMC5755158_010_01.png,                               # 图像名
+   'html': {
+     'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]},     # 表格的HTML字符串
+     'cells': [
+       {
+         'tokens': ['P', 'a', 'd', 'd', 'l', 'e'],                  # 表格中的单个文本
+         'bbox': [x0, y0, x1, y1]                                   # 表格中的单个文本的坐标
+       }
+     ]
+   }
+}
+```
+
+## 下载数据集
+
+请将数据集下载到本地。数据集文件结构如下:
+
+```text
+table-dataset/
+├── artificial      # 人工合成的表格
+│   ├── all         # 全部图片
+│   ├── all.txt
+│   ├── test.txt
+│   └── train.txt
+├── conv.v16i       # 常规版面中的表格,切图自 https://app.roboflow.com/yili-gxczm/yili_layout/16
+│   ├── all         # 全部图片
+│   ├── all.txt
+│   ├── test.txt
+│   └── train.txt
+└── unconv.v7i      # 非常规版面中的表格,切图自 https://app.roboflow.com/yili-gxczm/yili_layout_non_rec_for_seg/7
+    ├── all         # 全部图片
+    ├── all.txt
+    ├── test.txt
+    └── train.txt
+```
+
+上面的目录结构里,`all` 图片文件夹的标注存储在 `all.txt` 中。从 `all.txt` 中分割出的训练标签存储在 `train.txt`,测试标签存储在 `test.txt` 中。

+ 69 - 0
docs/scripts/table_model.sh

@@ -0,0 +1,69 @@
+#!/bin/bash
+# shellcheck disable=SC2155
+
+set -eux
+
+readonly CUR_TIME=$(date "+%s")
+
+# SLANet_ch 模型是 PaddleOCR 目前最优的中文表格预训练模型
+# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md
+readonly MODEL_NAME="SLANet_ch"
+readonly MODEL_CONF="configs/table/${TABLE_NAME}.yml"
+
+# 推理参数
+readonly INFER_IMG_DIR="train_data/table-dataset/conv.v16i/all"
+
+edit_model() {
+    vim "${MODEL_CONF}"
+}
+
+train_model() {
+    python3 tools/train.py -c "${MODEL_CONF}"
+}
+
+train_model_distr() {
+    python3 \
+        -m paddle.distributed.launch \
+        --gpus '0,1,2,3,4,5,6,7' \
+        tools/train.py -c "${MODEL_CONF}"
+}
+
+export_model() {
+    python3 tools/export_model.py \
+        -c "${MODEL_CONF}" \
+        -o Global.pretrained_model="./output/${MODEL_NAME}/best_accuracy" \
+        Global.save_inference_dir="./inference/${MODEL_NAME}"
+}
+
+infer_model() {
+    python3 ppstructure/table/predict_structure.py \
+        --table_model_dir=inference/"${MODEL_NAME}" \
+        --rec_char_dict_path="./ppocr/utils/ppocr_keys_v1.txt" \
+        --table_char_dict_path="./ppocr/utils/dict/table_structure_dict_ch.txt" \
+        --image_dir="${INFER_IMG_DIR}" \
+        --output="inference_results/${MODEL_NAME}_${CUR_TIME}"
+}
+
+main() {
+    case "${1}" in
+    edit)
+        edit_model
+        ;;
+    train)
+        train_model
+        ;;
+    train_distr)
+        train_model_distr
+        ;;
+    export)
+        export_model
+        ;;
+    infer)
+        infer_model
+        ;;
+    *)
+        echo "Invalid option: ${1}"
+        exit 1
+        ;;
+    esac
+}

+ 108 - 0
docs/train_and_eval.md

@@ -0,0 +1,108 @@
+# 表格结构模型训练与评估
+
+表格识别流程中包含三个模型:表格结构预测模型,单行文本检测模型,单行文本识别模型。我们目前对表格结构预测模型进行了训练。<br>
+官方文档:[表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/table/README_ch.md)
+
+## 准备数据集
+
+请参考:[数据集准备](./prepare_data.md)
+
+## 准备环境
+
+克隆 PaddleOCR 仓库,进入仓库目录:
+
+```bash
+git clone --depth 1 https://github.com/PaddlePaddle/PaddleOCR.git
+cd PaddleOCR
+```
+
+PaddleOCR 训练数据的默认存储路径是 `PaddleOCR/train_data`。我们将数据集下载到本地后,可以拷贝数据集或创建软链接到对应目录:
+
+```bash
+cp -r /path/to/table-dataset ./train_data/table-dataset
+# 或者
+ln -sf /path/to/table-dataset ./train_data/table-dataset
+```
+
+请将我们的训练脚本 [table_model.sh](./scripts/table_model.sh) 拷贝至 `PaddleOCR/` 路径下。<br>
+PaddleOCR 对训练过程做了模块化,如果要训练不同的模型,我们只需要在脚本开头更换配置文件。
+
+## 表格结构预测模型训练与评估
+
+### 训练
+
+以我们目前使用的 SLANet 模型为例(官方文档:[表格识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B),配置文件:[SLANet_ch.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/configs/table/SLANet_ch.yml)),修改配置文件如下:
+
+```bash
+$ cat configs/table/SLANet_ch.yml
+Global:
+  use_gpu: True
+  # 修改训练轮数
+  epoch_num: 400
+  # 修改为实际的预训练模型文件
+  pretrained_model: ./pretrain_models/ch_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
+...
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 5.0
+  lr:
+    # 修改学习率
+    learning_rate: 0.001
+...
+
+Train:
+  dataset:
+    name: PubTabDataSet
+    # 修改为实际训练集的目录
+    data_dir: train_data/table-dataset/artificial
+    # 修改为实际训练集的标签文件
+    label_file_list: [train_data/table-dataset/artificial/train.txt]
+...
+
+Eval:
+  dataset:
+    name: PubTabDataSet
+    # 修改为实际验证集的目录
+    data_dir: train_data/table-dataset/artificial/
+    # 修改为实际验证集的标签文件
+    label_file_list: [train_data/table-dataset/artificial/test.txt]
+...
+```
+
+其中,学习率 `learning_rate` (记为`lr`) 需要按运行时 `GPU卡数` (记为`GPU_number`) 和 `batch_size_per_card` (记为`batch_size`) 进行调整,公式为:
+
+**lr<sub>new</sub> = lr<sub>default</sub> \* (batch_size<sub>new</sub> \* GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> \* GPU_number<sub>default</sub>)**
+
+PaddleOCR 默认的配置文件对应 **batch_size<sub>default</sub>=8**,**GPU_number<sub>default</sub>=8**。
+
+更详细的参数调整说明,请参考官方文档:[模型微调](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/finetune.md)。更详细的配置项含义,请参考官方文档:[配置文件内容与生成](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/config.md)。
+
+训练模型:
+
+```bash
+# 单卡训练
+./table_model.sh train
+# 多卡训练
+./table_model.sh train_distr
+```
+
+导出模型:
+
+```bash
+./table_model.sh export
+```
+
+使用导出的模型推理:
+
+```bash
+./table_model.sh infer
+```
+
+更详细的模型训练,推理,部署说明请参考:[官方文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md)
+
+### 评估
+
+请参考官方文档的[评估方法](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#3-%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0%E4%B8%8E%E9%A2%84%E6%B5%8B)。