|
@@ -0,0 +1,114 @@
|
|
|
+# 表格结构模型训练与评估
|
|
|
+
|
|
|
+表格识别流程中包含三个模型:表格结构预测模型,单行文本检测模型,单行文本识别模型。我们目前对表格结构预测模型进行了训练。<br>
|
|
|
+官方文档:[表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/table/README_ch.md)
|
|
|
+
|
|
|
+## 准备数据集
|
|
|
+
|
|
|
+请参考:[数据集准备](./prepare_data.md)
|
|
|
+
|
|
|
+## 准备环境
|
|
|
+
|
|
|
+克隆 PaddleOCR 仓库,进入仓库目录:
|
|
|
+
|
|
|
+```bash
|
|
|
+git clone --depth 1 https://github.com/PaddlePaddle/PaddleOCR.git
|
|
|
+cd PaddleOCR
|
|
|
+```
|
|
|
+
|
|
|
+安装依赖:
|
|
|
+
|
|
|
+```bash
|
|
|
+pip install -r requirements.txt
|
|
|
+```
|
|
|
+
|
|
|
+PaddleOCR 训练数据的默认存储路径是 `PaddleOCR/train_data`。我们将数据集下载到本地后,可以拷贝数据集或创建软链接到对应目录:
|
|
|
+
|
|
|
+```bash
|
|
|
+cp -r /path/to/table-dataset ./train_data/table-dataset
|
|
|
+# 或者
|
|
|
+ln -sf /path/to/table-dataset ./train_data/table-dataset
|
|
|
+```
|
|
|
+
|
|
|
+请将我们的训练脚本 [table_model.sh](./scripts/table_model.sh) 拷贝至 `PaddleOCR/` 路径下。<br>
|
|
|
+PaddleOCR 对训练过程做了模块化,如果要训练不同的模型,我们只需要在脚本开头更换配置文件。
|
|
|
+
|
|
|
+## 表格结构预测模型训练与评估
|
|
|
+
|
|
|
+### 训练
|
|
|
+
|
|
|
+以我们目前使用的 SLANet 模型为例(官方文档:[表格识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/models_list.md#22-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B),配置文件:[SLANet_ch.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/configs/table/SLANet_ch.yml)),修改配置文件如下:
|
|
|
+
|
|
|
+```bash
|
|
|
+$ cat configs/table/SLANet_ch.yml
|
|
|
+Global:
|
|
|
+ use_gpu: True
|
|
|
+ # 训练轮数
|
|
|
+ epoch_num: 400
|
|
|
+ # 预训练模型文件
|
|
|
+ pretrained_model: ./pretrain_models/ch_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
|
|
|
+...
|
|
|
+
|
|
|
+Optimizer:
|
|
|
+ name: Adam
|
|
|
+ beta1: 0.9
|
|
|
+ beta2: 0.999
|
|
|
+ clip_norm: 5.0
|
|
|
+ lr:
|
|
|
+ # 学习率
|
|
|
+ learning_rate: 0.001
|
|
|
+...
|
|
|
+
|
|
|
+Train:
|
|
|
+ dataset:
|
|
|
+ name: PubTabDataSet
|
|
|
+ # 训练集目录
|
|
|
+ data_dir: train_data/table-dataset/artificial
|
|
|
+ # 训练集标注文件
|
|
|
+ label_file_list: [train_data/table-dataset/artificial/train.txt]
|
|
|
+...
|
|
|
+
|
|
|
+Eval:
|
|
|
+ dataset:
|
|
|
+ name: PubTabDataSet
|
|
|
+ # 验证集目录
|
|
|
+ data_dir: train_data/table-dataset/artificial/
|
|
|
+ # 验证集标注文件
|
|
|
+ label_file_list: [train_data/table-dataset/artificial/test.txt]
|
|
|
+...
|
|
|
+```
|
|
|
+
|
|
|
+其中,学习率 `learning_rate` (记为`lr`) 需要按运行时 `GPU卡数` (记为`GPU_number`) 和 `batch_size_per_card` (记为`batch_size`) 进行调整,公式为:
|
|
|
+
|
|
|
+**lr<sub>new</sub> = lr<sub>default</sub> \* (batch_size<sub>new</sub> \* GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> \* GPU_number<sub>default</sub>)**
|
|
|
+
|
|
|
+PaddleOCR 默认的配置文件对应 **batch_size<sub>default</sub>=8**,**GPU_number<sub>default</sub>=8**。
|
|
|
+
|
|
|
+更详细的参数调整说明,请参考官方文档:[模型微调](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/finetune.md)。更详细的配置项含义,请参考官方文档:[配置文件内容与生成](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/config.md)。
|
|
|
+
|
|
|
+训练模型:
|
|
|
+
|
|
|
+```bash
|
|
|
+# 单卡训练
|
|
|
+./table_model.sh train
|
|
|
+# 多卡训练
|
|
|
+./table_model.sh train_distr
|
|
|
+```
|
|
|
+
|
|
|
+导出模型:
|
|
|
+
|
|
|
+```bash
|
|
|
+./table_model.sh export
|
|
|
+```
|
|
|
+
|
|
|
+使用导出的模型推理:
|
|
|
+
|
|
|
+```bash
|
|
|
+./table_model.sh infer
|
|
|
+```
|
|
|
+
|
|
|
+更详细的模型训练,推理,部署说明请参考:[官方文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md)
|
|
|
+
|
|
|
+### 评估
|
|
|
+
|
|
|
+请参考官方文档的[评估方法](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/table_recognition.md#3-%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0%E4%B8%8E%E9%A2%84%E6%B5%8B)。
|