@@ -0,0 +1,384 @@
+import copy
+import time
+import os
+from pathlib import Path
+import logging
+import cv2
+import numpy as np
+
+import paddleocr
+from paddleocr.paddleocr import (
+    PPStructure,
+    parse_args,
+    check_gpu,
+    parse_lang,
+    get_model_config,
+    confirm_model_dir_url,
+    maybe_download,
+    check_img,
+    BASE_DIR,
+    SUPPORT_STRUCTURE_MODEL_VERSION,
+)
+from ppstructure.table.predict_table import (
+    TableSystem,
+    expand,
+    sorted_boxes,
+    logger,
+)
+from ppstructure.predict_system import LayoutPredictor, TextSystem
+
+
+class CustomPPStructure(PPStructure):
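+    """PPStructure variant with package-relative dictionary paths and custom table OCR.
+
+    The constructor mirrors PPStructure.__init__ (and StructureSystem.__init__) but
+    resolves the recognition/table/layout dictionaries relative to the installed
+    paddleocr package and builds a CustomTableSystem, so that __call__ can accept
+    the extra prefer_table_cell_boxes flag.
+    """
+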
+    def __init__(self, **kwargs):
+        params = parse_args(mMain=False)
+        params.__dict__.update(**kwargs)
+        assert (
+            params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION
+        ), "structure_version must be in {}, but got {}".format(
+            SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version
+        )
+        params.use_gpu = check_gpu(params.use_gpu)
+        params.mode = "structure"
+
+        if not params.show_log:
+            logger.setLevel(logging.INFO)
+        lang, det_lang = parse_lang(params.lang)
+        if lang == "ch":
+            table_lang = "ch"
+        else:
+            table_lang = "en"
+        if params.structure_version == "PP-Structure":
+            params.merge_no_span_structure = False
+
+        # init model dir
+        det_model_config = get_model_config(
+            "OCR", params.ocr_version, "det", det_lang
+        )
+        params.det_model_dir, det_url = confirm_model_dir_url(
+            params.det_model_dir,
+            os.path.join(BASE_DIR, "whl", "det", det_lang),
+            det_model_config["url"],
+        )
+        rec_model_config = get_model_config(
+            "OCR", params.ocr_version, "rec", lang
+        )
+        params.rec_model_dir, rec_url = confirm_model_dir_url(
+            params.rec_model_dir,
+            os.path.join(BASE_DIR, "whl", "rec", lang),
+            rec_model_config["url"],
+        )
+        table_model_config = get_model_config(
+            "STRUCTURE", params.structure_version, "table", table_lang
+        )
+        params.table_model_dir, table_url = confirm_model_dir_url(
+            params.table_model_dir,
+            os.path.join(BASE_DIR, "whl", "table"),
+            table_model_config["url"],
+        )
+        layout_model_config = get_model_config(
+            "STRUCTURE", params.structure_version, "layout", lang
+        )
+        params.layout_model_dir, layout_url = confirm_model_dir_url(
+            params.layout_model_dir,
+            os.path.join(BASE_DIR, "whl", "layout"),
+            layout_model_config["url"],
+        )
+        # download model
+        maybe_download(params.det_model_dir, det_url)
+        maybe_download(params.rec_model_dir, rec_url)
+        maybe_download(params.table_model_dir, table_url)
+        maybe_download(params.layout_model_dir, layout_url)
+
+        # ---------- Custom adjust start ------------ #
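+        # Resolve the default character/label dictionaries relative to the installed
+        # paddleocr package so that the bundled dict_path entries in the model
+        # configs point at real files without requiring explicit *_dict_path args.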
+        paddleocr_path = Path(paddleocr.__file__).parent
+        if params.rec_char_dict_path is None:
+            params.rec_char_dict_path = str(
+                paddleocr_path / rec_model_config["dict_path"]
+            )
+        if params.table_char_dict_path is None:
+            params.table_char_dict_path = str(
+                paddleocr_path / table_model_config["dict_path"]
+            )
+        if params.layout_dict_path is None:
+            params.layout_dict_path = str(
+                paddleocr_path / layout_model_config["dict_path"]
+            )
+        # ---------- Custom adjust end ------------- #
+        logger.debug(params)
+
+        # Code from ppstructure.predict_system.StructureSystem
+        args = params
+        self.mode = args.mode
+        self.recovery = args.recovery
+
+        self.image_orientation_predictor = None
+        if args.image_orientation:
+            import paddleclas
+
+            self.image_orientation_predictor = paddleclas.PaddleClas(
+                model_name="text_image_orientation"
+            )
+
+        if self.mode == "structure":
+            if not args.show_log:
+                logger.setLevel(logging.INFO)
+            if not args.layout and args.ocr:
+                args.ocr = False
+                logger.warning(
+                    "When args.layout is false, args.ocr is automatically set to false"
+                )
+            args.drop_score = 0
+            # init model
+            self.layout_predictor = None
+            self.text_system = None
+            self.table_system = None
+            if args.layout:
+                self.layout_predictor = LayoutPredictor(args)
+            if args.ocr:
+                self.text_system = TextSystem(args)
+            # ---------- Custom adjust start ------------ #
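+            # Build CustomTableSystem (instead of the stock TableSystem) so that
+            # table OCR can reuse the structure model's cell boxes when
+            # prefer_table_cell_boxes is requested at call time.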
+            if args.table:
+                if self.text_system is not None:
+                    self.table_system = CustomTableSystem(
+                        args,
+                        self.text_system.text_detector,
+                        self.text_system.text_recognizer,
+                    )
+                else:
+                    self.table_system = CustomTableSystem(args)
+            # ---------- Custom adjust end ------------- #
+
+        elif self.mode == "kie":
+            from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
+
+            self.kie_predictor = SerRePredictor(args)
+
+    def __call__(
+        self,
+        img,
+        return_ocr_result_in_table=False,
+        img_idx=0,
+        prefer_table_cell_boxes=False,
+    ):
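+        """Run structure analysis on an image.
+
+        When prefer_table_cell_boxes is True, table regions are OCR'd using the cell
+        boxes predicted by the table-structure model instead of the text detector's
+        boxes (see CustomTableSystem).
+        """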
+        img = check_img(img)
+        res, _ = self._super_call(
+            img,
+            return_ocr_result_in_table,
+            img_idx=img_idx,
+            prefer_table_cell_boxes=prefer_table_cell_boxes,
+        )
+        return res
+
+    def _super_call(
+        self,
+        img,
+        return_ocr_result_in_table=False,
+        img_idx=0,
+        prefer_table_cell_boxes=False,
+    ):
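+        """Copy of StructureSystem.__call__ that forwards prefer_table_cell_boxes
+        to CustomTableSystem; the rest of the logic is unchanged."""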
+        time_dict = {
+            "image_orientation": 0,
+            "layout": 0,
+            "table": 0,
+            "table_match": 0,
+            "det": 0,
+            "rec": 0,
+            "kie": 0,
+            "all": 0,
+        }
+        start = time.time()
+        if self.image_orientation_predictor is not None:
+            tic = time.time()
+            cls_result = self.image_orientation_predictor.predict(
+                input_data=img
+            )
+            cls_res = next(cls_result)
+            angle = cls_res[0]["label_names"][0]
+            cv_rotate_code = {
+                "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
+                "180": cv2.ROTATE_180,
+                "270": cv2.ROTATE_90_CLOCKWISE,
+            }
+            if angle in cv_rotate_code:
+                img = cv2.rotate(img, cv_rotate_code[angle])
+            toc = time.time()
+            time_dict["image_orientation"] = toc - tic
+        if self.mode == "structure":
+            ori_im = img.copy()
+            if self.layout_predictor is not None:
+                layout_res, elapse = self.layout_predictor(img)
+                time_dict["layout"] += elapse
+            else:
+                h, w = ori_im.shape[:2]
+                layout_res = [dict(bbox=None, label="table")]
+            res_list = []
+            for region in layout_res:
+                res = ""
+                if region["bbox"] is not None:
+                    x1, y1, x2, y2 = region["bbox"]
+                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+                    roi_img = ori_im[y1:y2, x1:x2, :]
+                else:
+                    x1, y1, x2, y2 = 0, 0, w, h
+                    roi_img = ori_im
+                if region["label"] == "table":
+                    if self.table_system is not None:
+                        # ---------- Custom adjust start ------------ #
+                        res, table_time_dict = self.table_system(
+                            roi_img,
+                            return_ocr_result_in_table,
+                            prefer_table_cell_boxes,
+                        )
+                        # ---------- Custom adjust end ------------- #
+                        time_dict["table"] += table_time_dict["table"]
+                        time_dict["table_match"] += table_time_dict["match"]
+                        time_dict["det"] += table_time_dict["det"]
+                        time_dict["rec"] += table_time_dict["rec"]
+                else:
+                    if self.text_system is not None:
+                        if self.recovery:
+                            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
+                            wht_im[y1:y2, x1:x2, :] = roi_img
+                            filter_boxes, filter_rec_res, ocr_time_dict = (
+                                self.text_system(wht_im)
+                            )
+                        else:
+                            filter_boxes, filter_rec_res, ocr_time_dict = (
+                                self.text_system(roi_img)
+                            )
+                        time_dict["det"] += ocr_time_dict["det"]
+                        time_dict["rec"] += ocr_time_dict["rec"]
+
+                        # Remove style tokens: the recognition model trained on the
+                        # PubTabNet dataset may emit formatting tags such as <b>.
+                        style_token = [
+                            "<strike>",
+                            "</strike>",
+                            "<sup>",
+                            "</sub>",
+                            "<b>",
+                            "</b>",
+                            "<sub>",
+                            "</sup>",
+                            "<overline>",
+                            "</overline>",
+                            "<underline>",
+                            "</underline>",
+                            "<i>",
+                            "</i>",
+                        ]
+                        res = []
+                        for box, rec_res in zip(filter_boxes, filter_rec_res):
+                            rec_str, rec_conf = rec_res
+                            for token in style_token:
+                                if token in rec_str:
+                                    rec_str = rec_str.replace(token, "")
+                            if not self.recovery:
+                                box += [x1, y1]
+                            res.append(
+                                {
+                                    "text": rec_str,
+                                    "confidence": float(rec_conf),
+                                    "text_region": box.tolist(),
+                                }
+                            )
+                res_list.append(
+                    {
+                        "type": region["label"].lower(),
+                        "bbox": [x1, y1, x2, y2],
+                        "img": roi_img,
+                        "res": res,
+                        "img_idx": img_idx,
+                    }
+                )
+            end = time.time()
+            time_dict["all"] = end - start
+            return res_list, time_dict
+        elif self.mode == "kie":
+            re_res, elapse = self.kie_predictor(img)
+            time_dict["kie"] = elapse
+            time_dict["all"] = elapse
+            return re_res[0], time_dict
+        return None, None
+
+
+class CustomTableSystem(TableSystem):
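+    """TableSystem variant that can OCR a table using the structure model's cell
+    boxes as text regions instead of running the generic text detector."""
+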
+    def __call__(
+        self,
+        img,
+        return_ocr_result_in_table=False,
+        prefer_table_cell_boxes=False,
+    ):
+        result = dict()
+        time_dict = {"det": 0, "rec": 0, "table": 0, "all": 0, "match": 0}
+        start = time.time()
+        structure_res, elapse = self._structure(copy.deepcopy(img))
+        result["cell_bbox"] = structure_res[1].tolist()
+        time_dict["table"] = elapse
+
+        # ---------- Custom adjust start ------------ #
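+        # Each entry in structure_res[1] is an 8-value cell polygon
+        # (x1, y1, ..., x4, y4); reshape it into the 4-point box format that
+        # _ocr / sorted_boxes expect from the text detector.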
+        if prefer_table_cell_boxes:
+            cell_boxes = (
+                np.array(
+                    [
+                        [[b[0], b[1]], [b[2], b[3]], [b[4], b[5]], [b[6], b[7]]]
+                        for b in structure_res[1]
+                    ]
+                )
+                .astype(int)
+                .astype(float)
+            )
+        else:
+            cell_boxes = None
+        dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
+            copy.deepcopy(img), cell_boxes=cell_boxes
+        )
+        # ---------- Custom adjust end ------------- #
+        time_dict["det"] = det_elapse
+        time_dict["rec"] = rec_elapse
+
+        if return_ocr_result_in_table:
+            result["boxes"] = [x.tolist() for x in dt_boxes]
+            result["rec_res"] = rec_res
+
+        tic = time.time()
+        pred_html = self.match(structure_res, dt_boxes, rec_res)
+        toc = time.time()
+        time_dict["match"] = toc - tic
+        result["html"] = pred_html
+        end = time.time()
+        time_dict["all"] = end - start
+        return result, time_dict
+
+    def _ocr(self, img, cell_boxes=None):
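+        """Detect and recognize text inside a table image.
+
+        When cell_boxes is provided, it is used directly as the detection result and
+        the text detector is skipped; otherwise this follows TableSystem._ocr.
+        """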
+        h, w = img.shape[:2]
+        # ---------- Custom adjust start ------------ #
+        if cell_boxes is None:
+            dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
+        else:
+            logger.debug("use cell_boxes instead of dt_boxes")
+            dt_boxes, det_elapse = cell_boxes, 0.0
+        # ---------- Custom adjust end ------------- #
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        r_boxes = []
+        for box in dt_boxes:
+            x_min = max(0, box[:, 0].min() - 1)
+            x_max = min(w, box[:, 0].max() + 1)
+            y_min = max(0, box[:, 1].min() - 1)
+            y_max = min(h, box[:, 1].max() + 1)
+            box = [x_min, y_min, x_max, y_max]
+            r_boxes.append(box)
+        dt_boxes = np.array(r_boxes)
+        logger.debug(
+            "dt_boxes num : {}, elapse : {}".format(len(dt_boxes), det_elapse)
+        )
+        if dt_boxes is None:
+            return None, None
+
+        img_crop_list = []
+        for i in range(len(dt_boxes)):
+            det_box = dt_boxes[i]
+            x0, y0, x1, y1 = expand(2, det_box, img.shape)
+            text_rect = img[int(y0) : int(y1), int(x0) : int(x1), :]
+            img_crop_list.append(text_rect)
+        rec_res, rec_elapse = self.text_recognizer(img_crop_list)
+        logger.debug(
+            "rec_res num : {}, elapse : {}".format(len(rec_res), rec_elapse)
+        )
+        return dt_boxes, rec_res, det_elapse, rec_elapse
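+
+
+# Minimal usage sketch (illustrative file name; assumes the PP-Structure models are
+# present locally or can be downloaded):
+#
+#   engine = CustomPPStructure(show_log=False, lang="en")
+#   regions = engine(
+#       "table_page.png",
+#       return_ocr_result_in_table=True,
+#       prefer_table_cell_boxes=True,
+#   )
+#   for region in regions:
+#       if region["type"] == "table":
+#           print(region["res"]["html"])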