Pārlūkot izejas kodu

feat: 添加表格识别策略, 可以选择使用结构预测的bbox代替文本检测的bbox

jingze_cheng 9 mēneši atpakaļ
vecāks
revīzija
4979c519e2
2 mainītis faili ar 399 papildinājumiem un 10 dzēšanām
  1. 384 0
      cores/table_engine.py
  2. 15 10
      server.py

+ 384 - 0
cores/table_engine.py

@@ -0,0 +1,384 @@
+import copy
+import time
+import os
+from pathlib import Path
+import logging
+import numpy as np
+
+import paddleocr
+from paddleocr.paddleocr import (
+    PPStructure,
+    parse_args,
+    check_gpu,
+    parse_lang,
+    get_model_config,
+    confirm_model_dir_url,
+    maybe_download,
+    check_img,
+    BASE_DIR,
+    SUPPORT_STRUCTURE_MODEL_VERSION,
+)
+from ppstructure.table.predict_table import (
+    TableSystem,
+    expand,
+    sorted_boxes,
+    logger,
+)
+from ppstructure.predict_system import LayoutPredictor, TextSystem
+
+
+class CustomPPStructure(PPStructure):
+    """PP-Structure pipeline that uses ``CustomTableSystem`` for table regions.
+
+    ``PPStructure.__init__`` is intentionally not called: the constructor
+    repeats the pipeline set-up itself (see the "Code from
+    ppstructure.predict_system.StructureSystem" note below) so that the table
+    branch can be built from ``CustomTableSystem``. Calls accept an extra
+    ``prefer_table_cell_boxes`` flag which is forwarded to that table system.
+    """
+
+    # Style markup stripped from the text recognised in non-table regions.
+    # Every opening tag is paired with its closing tag.
+    _STYLE_TOKENS = (
+        "<strike>",
+        "</strike>",
+        "<sup>",
+        "</sup>",
+        "<sub>",
+        "</sub>",
+        "<b>",
+        "</b>",
+        "<overline>",
+        "</overline>",
+        "<underline>",
+        "</underline>",
+        "<i>",
+        "</i>",
+    )
+
+    def __init__(self, **kwargs):
+        """Resolve model directories, download missing models, build predictors.
+
+        Args:
+            **kwargs: overrides applied on top of the defaults returned by
+                ``parse_args(mMain=False)`` (e.g. ``layout``, ``table``,
+                ``use_gpu``, ``show_log``).
+
+        Raises:
+            AssertionError: if ``structure_version`` is not listed in
+                ``SUPPORT_STRUCTURE_MODEL_VERSION``.
+        """
+        params = parse_args(mMain=False)
+        params.__dict__.update(**kwargs)
+        assert (
+            params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION
+        ), "structure_version must in {}, but get {}".format(
+            SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version
+        )
+        params.use_gpu = check_gpu(params.use_gpu)
+        # This class always runs the structure pipeline.
+        params.mode = "structure"
+
+        if not params.show_log:
+            logger.setLevel(logging.INFO)
+        lang, det_lang = parse_lang(params.lang)
+        # Table models are only looked up for "ch" and "en".
+        table_lang = "ch" if lang == "ch" else "en"
+        if params.structure_version == "PP-Structure":
+            params.merge_no_span_structure = False
+
+        # init model dir
+        det_model_config = get_model_config(
+            "OCR", params.ocr_version, "det", det_lang
+        )
+        params.det_model_dir, det_url = confirm_model_dir_url(
+            params.det_model_dir,
+            os.path.join(BASE_DIR, "whl", "det", det_lang),
+            det_model_config["url"],
+        )
+        rec_model_config = get_model_config(
+            "OCR", params.ocr_version, "rec", lang
+        )
+        params.rec_model_dir, rec_url = confirm_model_dir_url(
+            params.rec_model_dir,
+            os.path.join(BASE_DIR, "whl", "rec", lang),
+            rec_model_config["url"],
+        )
+        table_model_config = get_model_config(
+            "STRUCTURE", params.structure_version, "table", table_lang
+        )
+        params.table_model_dir, table_url = confirm_model_dir_url(
+            params.table_model_dir,
+            os.path.join(BASE_DIR, "whl", "table"),
+            table_model_config["url"],
+        )
+        layout_model_config = get_model_config(
+            "STRUCTURE", params.structure_version, "layout", lang
+        )
+        params.layout_model_dir, layout_url = confirm_model_dir_url(
+            params.layout_model_dir,
+            os.path.join(BASE_DIR, "whl", "layout"),
+            layout_model_config["url"],
+        )
+        # download model
+        maybe_download(params.det_model_dir, det_url)
+        maybe_download(params.rec_model_dir, rec_url)
+        maybe_download(params.table_model_dir, table_url)
+        maybe_download(params.layout_model_dir, layout_url)
+
+        # ---------- Custom adjust start ------------ #
+        # Resolve the dictionary files relative to the installed paddleocr
+        # package unless the caller supplied explicit paths.
+        paddleocr_path = Path(paddleocr.__file__).parent
+        if params.rec_char_dict_path is None:
+            params.rec_char_dict_path = str(
+                paddleocr_path / rec_model_config["dict_path"]
+            )
+        if params.table_char_dict_path is None:
+            params.table_char_dict_path = str(
+                paddleocr_path / table_model_config["dict_path"]
+            )
+        if params.layout_dict_path is None:
+            params.layout_dict_path = str(
+                paddleocr_path / layout_model_config["dict_path"]
+            )
+        # ---------- Custom adjust end ------------- #
+        logger.debug(params)
+
+        # Code from ppstructure.predict_system.StructureSystem
+        args = params
+        self.mode = args.mode
+        self.recovery = args.recovery
+
+        self.image_orientation_predictor = None
+        if args.image_orientation:
+            import paddleclas
+
+            self.image_orientation_predictor = paddleclas.PaddleClas(
+                model_name="text_image_orientation"
+            )
+
+        if self.mode == "structure":
+            if not args.layout and args.ocr:
+                args.ocr = False
+                logger.warning(
+                    "When args.layout is false, args.ocr is automatically set to false"
+                )
+            args.drop_score = 0
+            # init model
+            self.layout_predictor = None
+            self.text_system = None
+            self.table_system = None
+            if args.layout:
+                self.layout_predictor = LayoutPredictor(args)
+                if args.ocr:
+                    self.text_system = TextSystem(args)
+            # ---------- Custom adjust start ------------ #
+            if args.table:
+                if self.text_system is not None:
+                    # Share the detector/recognizer already built for text.
+                    self.table_system = CustomTableSystem(
+                        args,
+                        self.text_system.text_detector,
+                        self.text_system.text_recognizer,
+                    )
+                else:
+                    self.table_system = CustomTableSystem(args)
+            # ---------- Custom adjust end ------------- #
+
+        elif self.mode == "kie":
+            from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
+
+            self.kie_predictor = SerRePredictor(args)
+
+    def __call__(
+        self,
+        img,
+        return_ocr_result_in_table=False,
+        img_idx=0,
+        prefer_table_cell_boxes=False,
+    ):
+        """Run the pipeline on ``img`` and return only the result list.
+
+        Args:
+            img: input accepted by ``check_img``.
+            return_ocr_result_in_table: forwarded to the table system; when
+                true the table result also carries the OCR boxes and texts.
+            img_idx: value stored under ``"img_idx"`` in every region result.
+            prefer_table_cell_boxes: when true, the table system recognises
+                text inside the structure model's cell boxes instead of
+                running text detection.
+
+        Returns:
+            The first element returned by ``_super_call`` (timings dropped).
+        """
+        img = check_img(img)
+        res, _ = self._super_call(
+            img,
+            return_ocr_result_in_table,
+            img_idx=img_idx,
+            prefer_table_cell_boxes=prefer_table_cell_boxes,
+        )
+        return res
+
+    def _correct_orientation(self, img):
+        """Rotate ``img`` according to the orientation classifier's label."""
+        # Imported lazily, like paddleclas in __init__: only needed when the
+        # orientation predictor is enabled.
+        import cv2
+
+        cls_result = self.image_orientation_predictor.predict(input_data=img)
+        cls_res = next(cls_result)
+        angle = cls_res[0]["label_names"][0]
+        cv_rotate_code = {
+            "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
+            "180": cv2.ROTATE_180,
+            "270": cv2.ROTATE_90_CLOCKWISE,
+        }
+        # Any other label leaves the image untouched.
+        if angle in cv_rotate_code:
+            img = cv2.rotate(img, cv_rotate_code[angle])
+        return img
+
+    def _recognize_text_region(self, ori_im, roi_img, bbox, time_dict):
+        """Run the text system on one non-table region.
+
+        Args:
+            ori_im: the full page image.
+            roi_img: the crop of ``ori_im`` described by ``bbox``.
+            bbox: ``(x1, y1, x2, y2)`` integer region coordinates.
+            time_dict: timing accumulator; ``"det"`` and ``"rec"`` are updated
+                in place.
+
+        Returns:
+            A list of dicts with ``"text"``, ``"confidence"`` and
+            ``"text_region"`` keys.
+        """
+        x1, y1, x2, y2 = bbox
+        if self.recovery:
+            # Paste the crop onto a page-sized canvas of ones so the returned
+            # boxes are already expressed in page coordinates.
+            ocr_input = np.ones(ori_im.shape, dtype=ori_im.dtype)
+            ocr_input[y1:y2, x1:x2, :] = roi_img
+        else:
+            ocr_input = roi_img
+        filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+            ocr_input
+        )
+        time_dict["det"] += ocr_time_dict["det"]
+        time_dict["rec"] += ocr_time_dict["rec"]
+
+        # remove style char,
+        # when using the recognition model trained on the PubtabNet dataset,
+        # it will recognize the text format in the table, such as <b>
+        res = []
+        for box, (rec_str, rec_conf) in zip(filter_boxes, filter_rec_res):
+            for token in self._STYLE_TOKENS:
+                rec_str = rec_str.replace(token, "")
+            if not self.recovery:
+                # Boxes are relative to the crop; shift them to page coords.
+                box += [x1, y1]
+            res.append(
+                {
+                    "text": rec_str,
+                    "confidence": float(rec_conf),
+                    "text_region": box.tolist(),
+                }
+            )
+        return res
+
+    def _super_call(
+        self,
+        img,
+        return_ocr_result_in_table=False,
+        img_idx=0,
+        prefer_table_cell_boxes=False,
+    ):
+        """Run the pipeline and return ``(results, time_dict)``.
+
+        Args:
+            img: image array (``.copy()`` and ``.shape`` are used on it).
+            return_ocr_result_in_table: forwarded to the table system.
+            img_idx: value stored under ``"img_idx"`` in every region result.
+            prefer_table_cell_boxes: forwarded to the table system.
+
+        Returns:
+            In "structure" mode a list with one dict per region (``type``,
+            ``bbox``, ``img``, ``res``, ``img_idx``) plus the timing dict; in
+            "kie" mode the first KIE result plus the timing dict; otherwise
+            ``(None, None)``.
+        """
+        time_dict = {
+            "image_orientation": 0,
+            "layout": 0,
+            "table": 0,
+            "table_match": 0,
+            "det": 0,
+            "rec": 0,
+            "kie": 0,
+            "all": 0,
+        }
+        start = time.time()
+        if self.image_orientation_predictor is not None:
+            tic = time.time()
+            img = self._correct_orientation(img)
+            time_dict["image_orientation"] = time.time() - tic
+
+        if self.mode == "kie":
+            re_res, elapse = self.kie_predictor(img)
+            time_dict["kie"] = elapse
+            time_dict["all"] = elapse
+            return re_res[0], time_dict
+        if self.mode != "structure":
+            return None, None
+
+        ori_im = img.copy()
+        # Needed whenever a region has no bbox, so bind unconditionally.
+        h, w = ori_im.shape[:2]
+        if self.layout_predictor is not None:
+            layout_res, elapse = self.layout_predictor(img)
+            time_dict["layout"] += elapse
+        else:
+            # Without layout analysis the whole image is one table region.
+            layout_res = [dict(bbox=None, label="table")]
+
+        res_list = []
+        for region in layout_res:
+            res = ""
+            if region["bbox"] is not None:
+                x1, y1, x2, y2 = (int(v) for v in region["bbox"])
+                roi_img = ori_im[y1:y2, x1:x2, :]
+            else:
+                x1, y1, x2, y2 = 0, 0, w, h
+                roi_img = ori_im
+            if region["label"] == "table":
+                if self.table_system is not None:
+                    # ---------- Custom adjust start ------------ #
+                    res, table_time_dict = self.table_system(
+                        roi_img,
+                        return_ocr_result_in_table,
+                        prefer_table_cell_boxes,
+                    )
+                    # ---------- Custom adjust end ------------- #
+                    time_dict["table"] += table_time_dict["table"]
+                    time_dict["table_match"] += table_time_dict["match"]
+                    time_dict["det"] += table_time_dict["det"]
+                    time_dict["rec"] += table_time_dict["rec"]
+            elif self.text_system is not None:
+                res = self._recognize_text_region(
+                    ori_im, roi_img, (x1, y1, x2, y2), time_dict
+                )
+            res_list.append(
+                {
+                    "type": region["label"].lower(),
+                    "bbox": [x1, y1, x2, y2],
+                    "img": roi_img,
+                    "res": res,
+                    "img_idx": img_idx,
+                }
+            )
+        time_dict["all"] = time.time() - start
+        return res_list, time_dict
+
+
+class CustomTableSystem(TableSystem):
+    """Table system that can OCR the structure model's cell boxes directly.
+
+    With ``prefer_table_cell_boxes`` set, text detection is skipped and the
+    cell boxes predicted by the structure model are used as the regions handed
+    to the text recognizer.
+    """
+
+    @staticmethod
+    def _cells_to_quads(cell_bboxes):
+        """Convert structure-model cell boxes to an ``(N, 4, 2)`` float array.
+
+        Rows of 8 values are read as four ``(x, y)`` corner points. Rows of 4
+        values are read as ``[x1, y1, x2, y2]`` and expanded to the four
+        corners. Coordinates are truncated to whole numbers.
+
+        Raises:
+            ValueError: if rows have neither 4 nor 8 values.
+        """
+        boxes = np.asarray(cell_bboxes, dtype=float)
+        if boxes.size == 0:
+            return np.zeros((0, 4, 2), dtype=float)
+        boxes = boxes.reshape(len(boxes), -1)
+        if boxes.shape[1] == 4:
+            x1, y1, x2, y2 = boxes.T
+            boxes = np.stack([x1, y1, x2, y1, x2, y2, x1, y2], axis=1)
+        elif boxes.shape[1] != 8:
+            raise ValueError(
+                "cell boxes must have 4 or 8 values, got {}".format(
+                    boxes.shape[1]
+                )
+            )
+        return boxes.astype(int).astype(float).reshape(-1, 4, 2)
+
+    def __call__(
+        self,
+        img,
+        return_ocr_result_in_table=False,
+        prefer_table_cell_boxes=False,
+    ):
+        """Recognise one table image.
+
+        Args:
+            img: table image array.
+            return_ocr_result_in_table: when true, add ``"boxes"`` and
+                ``"rec_res"`` to the result.
+            prefer_table_cell_boxes: when true, recognise text inside the
+                structure model's cell boxes instead of detected text boxes.
+
+        Returns:
+            ``(result, time_dict)``; ``result`` holds ``"cell_bbox"`` and
+            ``"html"`` (plus the optional OCR entries), ``time_dict`` holds
+            the ``det``/``rec``/``table``/``match``/``all`` timings.
+        """
+        result = dict()
+        time_dict = {"det": 0, "rec": 0, "table": 0, "all": 0, "match": 0}
+        start = time.time()
+        structure_res, elapse = self._structure(copy.deepcopy(img))
+        result["cell_bbox"] = structure_res[1].tolist()
+        time_dict["table"] = elapse
+
+        # ---------- Custom adjust start ------------ #
+        cell_boxes = (
+            self._cells_to_quads(structure_res[1])
+            if prefer_table_cell_boxes
+            else None
+        )
+        dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
+            copy.deepcopy(img), cell_boxes=cell_boxes
+        )
+        # ---------- Custom adjust end ------------- #
+        time_dict["det"] = det_elapse
+        time_dict["rec"] = rec_elapse
+
+        if return_ocr_result_in_table:
+            result["boxes"] = [x.tolist() for x in dt_boxes]
+            result["rec_res"] = rec_res
+
+        tic = time.time()
+        pred_html = self.match(structure_res, dt_boxes, rec_res)
+        time_dict["match"] = time.time() - tic
+        result["html"] = pred_html
+        time_dict["all"] = time.time() - start
+        return result, time_dict
+
+    def _ocr(self, img, cell_boxes=None):
+        """Detect (or reuse) text boxes and recognise the text inside them.
+
+        Args:
+            img: table image array.
+            cell_boxes: optional ``(N, 4, 2)`` boxes used in place of text
+                detection.
+
+        Returns:
+            ``(dt_boxes, rec_res, det_elapse, rec_elapse)`` where ``dt_boxes``
+            is an array of ``[x_min, y_min, x_max, y_max]`` rows.
+        """
+        h, w = img.shape[:2]
+        # ---------- Custom adjust start ------------ #
+        if cell_boxes is None:
+            dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
+        else:
+            logger.debug("use cell_boxes instead of dt_boxes")
+            dt_boxes, det_elapse = cell_boxes, 0.0
+        # ---------- Custom adjust end ------------- #
+        # Defensive: treat a missing detection result as "no boxes" so the
+        # caller always gets the 4-tuple it unpacks.
+        if dt_boxes is None:
+            dt_boxes = np.zeros((0, 4, 2), dtype=float)
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        # Collapse each quad to an axis-aligned box padded by one pixel and
+        # clipped to the image.
+        r_boxes = [
+            [
+                max(0, box[:, 0].min() - 1),
+                max(0, box[:, 1].min() - 1),
+                min(w, box[:, 0].max() + 1),
+                min(h, box[:, 1].max() + 1),
+            ]
+            for box in dt_boxes
+        ]
+        dt_boxes = np.array(r_boxes)
+        logger.debug(
+            "dt_boxes num : {}, elapse : {}".format(len(dt_boxes), det_elapse)
+        )
+
+        img_crop_list = []
+        for det_box in dt_boxes:
+            x0, y0, x1, y1 = expand(2, det_box, img.shape)
+            img_crop_list.append(img[int(y0) : int(y1), int(x0) : int(x1), :])
+        rec_res, rec_elapse = self.text_recognizer(img_crop_list)
+        logger.debug(
+            "rec_res num  : {}, elapse : {}".format(len(rec_res), rec_elapse)
+        )
+        return dt_boxes, rec_res, det_elapse, rec_elapse

+ 15 - 10
server.py

@@ -1,8 +1,8 @@
 # -*- coding: UTF-8 -*-
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from paddleocr import PPStructure
+from pydantic import BaseModel, Field
+from cores.table_engine import CustomPPStructure
 from sx_utils.sxweb import *
 from sx_utils.sximage import *
 import threading
@@ -28,7 +28,7 @@ app.add_middleware(
 
 table_engine_lock = threading.Lock()
 # 表格识别模型
-table_engine = PPStructure(layout=False,
+table_engine = CustomPPStructure(layout=False,
                            table=True,
                            use_gpu=True,
                            show_log=True,
@@ -138,13 +138,14 @@ def get_zero_degree_image(img):
     return img
 
 
-def table_res(im, ROTATE=-1):
+def table_res(im, ROTATE=-1, prefer_cell=False):
     """
     获取图像中表格的识别结果和HTML字符串。
 
     Parameters:
         im (np.ndarray): 输入的图像数组。
         ROTATE (int): 旋转角度,默认为-1。
+        prefer_cell (bool): 是否使用cell_boxes替代dt_boxes,默认为否。
 
     Returns:
         Tuple: 表格识别结果和HTML字符串。
@@ -153,7 +154,7 @@ def table_res(im, ROTATE=-1):
     img = get_zero_degree_image(im)
     try:
         table_engine_lock.acquire()
-        res = table_engine(img)
+        res = table_engine(img, prefer_table_cell_boxes=prefer_cell)
     finally:
         table_engine_lock.release()
     html = res[0]['res']['html']
@@ -163,6 +164,10 @@ def table_res(im, ROTATE=-1):
 class TableInfo(BaseModel):
     image: str
     det: str
+    prefer_cell: bool = Field(
+        default=False, 
+        description="是否使用cell_boxes替代dt_boxes"
+    )
 
 
 @app.get("/ping")
@@ -178,25 +183,25 @@ def ping():
 
 @app.post("/ocr_system/table")
 @web_try()
-def table(image: TableInfo):
+def table(info: TableInfo):
     """
     对图像中的表格进行识别并返回HTML字符串。
 
     Parameters:
-        image (TableInfo): 输入的图像信息。
+        info (TableInfo): 输入的图像信息。
 
     Returns:
         dict: 包含HTML字符串的字典。
     """
     # 转换图片格式
-    img = base64_to_np(image.image)
+    img = base64_to_np(info.image)
     # 进行表格识别
-    res, html = table_res(img)
+    res, html = table_res(img, prefer_cell=info.prefer_cell)
     # 创建Table实例
     table = Table(html, img)
     # 效果不好则重新识别
     if table.check_html():
-        res, html = table_res(table.img)
+        res, html = table_res(table.img, prefer_cell=info.prefer_cell)
 
     if html:
         post_handler = PostHandler(html)