train_and_eval.md 3.7 KB

表格结构模型训练与评估

PaddleOCR 的表格识别流程中包含三个模型:表格结构预测模型,单行文本检测模型,单行文本识别模型。我们目前对表格结构预测模型进行了训练。

准备数据集

请参考:数据集准备

准备环境

克隆 PaddleOCR 仓库,进入仓库目录:

git clone --depth 1 https://github.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR

安装依赖:

pip install -r requirements.txt

PaddleOCR 训练数据的默认存储路径是 PaddleOCR/train_data。我们将数据集下载到本地后,可以拷贝数据集或创建软链接到对应目录:

cp -r /path/to/table-dataset ./train_data/table-dataset
# 或者
ln -sf /path/to/table-dataset ./train_data/table-dataset

请将我们的训练脚本 table_model.sh 拷贝至 PaddleOCR/ 路径下。
PaddleOCR 对训练过程做了模块化,如果要训练不同的模型,我们只需要在脚本开头更换配置文件。

表格结构预测模型训练与评估

训练

以我们目前使用的 SLANet 模型为例(官方文档:表格识别模型,配置文件:SLANet_ch.yml),修改配置文件如下:

$ cat configs/table/SLANet_ch.yml
Global:
  use_gpu: True
  # 训练轮数
  epoch_num: 400
  # 预训练模型文件
  pretrained_model: ./pretrain_models/ch_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
...

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  clip_norm: 5.0
  lr:
    # 学习率
    learning_rate: 0.001
...

Train:
  dataset:
    name: PubTabDataSet
    # 训练集目录
    data_dir: train_data/table-dataset/artificial
    # 训练集标注文件
    label_file_list: [train_data/table-dataset/artificial/train.txt]
...

Eval:
  dataset:
    name: PubTabDataSet
    # 验证集目录
    data_dir: train_data/table-dataset/artificial/
    # 验证集标注文件
    label_file_list: [train_data/table-dataset/artificial/test.txt]
...

其中,学习率 learning_rate (记为lr) 需要按运行时 GPU卡数 (记为GPU_number) 和 batch_size_per_card (记为batch_size) 进行调整,公式为:

lrnew = lrdefault * (batch_sizenew * GPU_numbernew) / (batch_sizedefault * GPU_numberdefault)

PaddleOCR 默认的配置文件对应 batch_sizedefault=8GPU_numberdefault=8

更详细的参数调整说明,请参考官方文档:模型微调。更详细的配置项含义,请参考官方文档:配置文件内容与生成

训练模型:

# 单卡训练
./table_model.sh train
# 多卡训练
./table_model.sh train_distr

导出模型:

./table_model.sh export

使用导出的模型推理:

./table_model.sh infer

更详细的模型训练,推理,部署说明请参考:官方文档

评估

请参考官方文档的评估方法