123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384 |
- import copy
- import time
- import os
- from pathlib import Path
- import logging
- import numpy as np
- import paddleocr
- from paddleocr.paddleocr import (
- PPStructure,
- parse_args,
- check_gpu,
- parse_lang,
- get_model_config,
- confirm_model_dir_url,
- maybe_download,
- check_img,
- BASE_DIR,
- SUPPORT_STRUCTURE_MODEL_VERSION,
- )
- from ppstructure.table.predict_table import (
- TableSystem,
- expand,
- sorted_boxes,
- logger,
- )
- from ppstructure.predict_system import LayoutPredictor, TextSystem
class CustomPPStructure(PPStructure):
    """PP-Structure pipeline whose table regions are handled by ``CustomTableSystem``.

    The constructor does not call ``PPStructure.__init__``; it re-implements the
    set-up inline (argument parsing, model directory resolution, model download,
    predictor construction) so that two adjustments can be made:

    * dictionary paths that were not supplied are resolved relative to the
      installed ``paddleocr`` package directory, and
    * the table predictor is a ``CustomTableSystem``, which accepts the extra
      ``prefer_table_cell_boxes`` switch forwarded by ``__call__``.
    """

    # Markup tokens removed from recognised text. The source comment explains
    # that a recognition model trained on PubTabNet emits text-format tags
    # such as <b>. Order is kept as in the original list; the duplicated
    # "<strike>" entry has been replaced by the missing "</strike>".
    _STYLE_TOKENS = (
        "<strike>",
        "</strike>",
        "<sup>",
        "</sub>",
        "<b>",
        "</b>",
        "<sub>",
        "</sup>",
        "<overline>",
        "</overline>",
        "<underline>",
        "</underline>",
        "<i>",
        "</i>",
    )

    def __init__(self, **kwargs):
        """Parse arguments, resolve/download models and build the predictors.

        Args:
            **kwargs: Overrides applied on top of the defaults returned by
                ``parse_args(mMain=False)``.

        Raises:
            AssertionError: If ``structure_version`` is not listed in
                ``SUPPORT_STRUCTURE_MODEL_VERSION``.
        """
        params = parse_args(mMain=False)
        params.__dict__.update(**kwargs)
        assert (
            params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION
        ), "structure_version must in {}, but get {}".format(
            SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version
        )
        params.use_gpu = check_gpu(params.use_gpu)
        params.mode = "structure"

        if not params.show_log:
            logger.setLevel(logging.INFO)

        lang, det_lang = parse_lang(params.lang)
        # Table models are looked up under "ch" or "en" only.
        table_lang = "ch" if lang == "ch" else "en"
        if params.structure_version == "PP-Structure":
            params.merge_no_span_structure = False

        # init model dir
        det_model_config = get_model_config(
            "OCR", params.ocr_version, "det", det_lang
        )
        params.det_model_dir, det_url = confirm_model_dir_url(
            params.det_model_dir,
            os.path.join(BASE_DIR, "whl", "det", det_lang),
            det_model_config["url"],
        )
        rec_model_config = get_model_config(
            "OCR", params.ocr_version, "rec", lang
        )
        params.rec_model_dir, rec_url = confirm_model_dir_url(
            params.rec_model_dir,
            os.path.join(BASE_DIR, "whl", "rec", lang),
            rec_model_config["url"],
        )
        table_model_config = get_model_config(
            "STRUCTURE", params.structure_version, "table", table_lang
        )
        params.table_model_dir, table_url = confirm_model_dir_url(
            params.table_model_dir,
            os.path.join(BASE_DIR, "whl", "table"),
            table_model_config["url"],
        )
        layout_model_config = get_model_config(
            "STRUCTURE", params.structure_version, "layout", lang
        )
        params.layout_model_dir, layout_url = confirm_model_dir_url(
            params.layout_model_dir,
            os.path.join(BASE_DIR, "whl", "layout"),
            layout_model_config["url"],
        )

        # download model
        maybe_download(params.det_model_dir, det_url)
        maybe_download(params.rec_model_dir, rec_url)
        maybe_download(params.table_model_dir, table_url)
        maybe_download(params.layout_model_dir, layout_url)

        # ---------- Custom adjust start ------------ #
        # Dictionary paths in the model configs are joined onto the directory
        # of the installed paddleocr package when the caller gave none.
        paddleocr_path = Path(paddleocr.__file__).parent
        if params.rec_char_dict_path is None:
            params.rec_char_dict_path = str(
                paddleocr_path / rec_model_config["dict_path"]
            )
        if params.table_char_dict_path is None:
            params.table_char_dict_path = str(
                paddleocr_path / table_model_config["dict_path"]
            )
        if params.layout_dict_path is None:
            params.layout_dict_path = str(
                paddleocr_path / layout_model_config["dict_path"]
            )
        # ---------- Custom adjust end ------------- #
        logger.debug(params)

        # Code from ppstructure.predict_system.StructureSystem
        args = params
        self.mode = args.mode
        self.recovery = args.recovery

        self.image_orientation_predictor = None
        if args.image_orientation:
            import paddleclas

            self.image_orientation_predictor = paddleclas.PaddleClas(
                model_name="text_image_orientation"
            )

        if self.mode == "structure":
            # The logger level was already lowered above when show_log is
            # off, so it is not repeated here.
            if not args.layout and args.ocr:
                args.ocr = False
                logger.warning(
                    "When args.layout is false, args.ocr is automatically set to false"
                )
            args.drop_score = 0

            # init model
            self.layout_predictor = None
            self.text_system = None
            self.table_system = None
            if args.layout:
                self.layout_predictor = LayoutPredictor(args)
                if args.ocr:
                    self.text_system = TextSystem(args)
            # ---------- Custom adjust start ------------ #
            if args.table:
                if self.text_system is not None:
                    # Share the detector/recognizer already built for OCR.
                    self.table_system = CustomTableSystem(
                        args,
                        self.text_system.text_detector,
                        self.text_system.text_recognizer,
                    )
                else:
                    self.table_system = CustomTableSystem(args)
            # ---------- Custom adjust end ------------- #
        elif self.mode == "kie":
            # params.mode is forced to "structure" above, so this branch is
            # not taken by this constructor; it is kept from the copied code.
            from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor

            self.kie_predictor = SerRePredictor(args)

    def __call__(
        self,
        img,
        return_ocr_result_in_table=False,
        img_idx=0,
        prefer_table_cell_boxes=False,
    ):
        """Run the pipeline on one image and return only the result.

        Args:
            img: Input passed through ``check_img`` before prediction.
            return_ocr_result_in_table: Forwarded to the table system.
            img_idx: Stored as ``"img_idx"`` in every region dict.
            prefer_table_cell_boxes: Forwarded to the table system; when true
                the predicted cell boxes replace text detection there.

        Returns:
            The first element returned by ``_super_call`` (timings are dropped).
        """
        img = check_img(img)
        res, _ = self._super_call(
            img,
            return_ocr_result_in_table,
            img_idx=img_idx,
            prefer_table_cell_boxes=prefer_table_cell_boxes,
        )
        return res

    @classmethod
    def _strip_style_tokens(cls, text):
        """Return *text* with every token in ``_STYLE_TOKENS`` removed."""
        for token in cls._STYLE_TOKENS:
            text = text.replace(token, "")
        return text

    def _super_call(
        self,
        img,
        return_ocr_result_in_table=False,
        img_idx=0,
        prefer_table_cell_boxes=False,
    ):
        """Predict layout regions and run table/OCR recognition on each.

        Args:
            img: Image array (indexed as ``img[y, x, channel]`` below).
            return_ocr_result_in_table: Forwarded to the table system.
            img_idx: Stored as ``"img_idx"`` in every region dict.
            prefer_table_cell_boxes: Forwarded to the table system.

        Returns:
            ``(res_list, time_dict)`` in ``"structure"`` mode, where each item
            of ``res_list`` is a dict with keys ``type``, ``bbox``, ``img``,
            ``res`` and ``img_idx``; ``(re_res[0], time_dict)`` in ``"kie"``
            mode; ``(None, None)`` for any other mode.
        """
        time_dict = {
            "image_orientation": 0,
            "layout": 0,
            "table": 0,
            "table_match": 0,
            "det": 0,
            "rec": 0,
            "kie": 0,
            "all": 0,
        }
        start = time.time()

        if self.image_orientation_predictor is not None:
            # cv2 is not imported at module level, so it is imported here,
            # where it is needed, to avoid a NameError.
            import cv2

            tic = time.time()
            cls_result = self.image_orientation_predictor.predict(
                input_data=img
            )
            cls_res = next(cls_result)
            angle = cls_res[0]["label_names"][0]
            cv_rotate_code = {
                "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
                "180": cv2.ROTATE_180,
                "270": cv2.ROTATE_90_CLOCKWISE,
            }
            if angle in cv_rotate_code:
                img = cv2.rotate(img, cv_rotate_code[angle])
            toc = time.time()
            time_dict["image_orientation"] = toc - tic

        if self.mode == "structure":
            ori_im = img.copy()
            # Computed unconditionally so a region without a bbox can always
            # fall back to the full image, whichever branch produced it.
            h, w = ori_im.shape[:2]
            if self.layout_predictor is not None:
                layout_res, elapse = self.layout_predictor(img)
                time_dict["layout"] += elapse
            else:
                # No layout model: treat the whole image as a single table.
                layout_res = [dict(bbox=None, label="table")]

            res_list = []
            for region in layout_res:
                res = ""
                if region["bbox"] is not None:
                    x1, y1, x2, y2 = (int(v) for v in region["bbox"])
                    roi_img = ori_im[y1:y2, x1:x2, :]
                else:
                    x1, y1, x2, y2 = 0, 0, w, h
                    roi_img = ori_im

                if region["label"] == "table":
                    if self.table_system is not None:
                        # ---------- Custom adjust start ------------ #
                        res, table_time_dict = self.table_system(
                            roi_img,
                            return_ocr_result_in_table,
                            prefer_table_cell_boxes,
                        )
                        # ---------- Custom adjust end ------------- #
                        time_dict["table"] += table_time_dict["table"]
                        time_dict["table_match"] += table_time_dict["match"]
                        time_dict["det"] += table_time_dict["det"]
                        time_dict["rec"] += table_time_dict["rec"]
                elif self.text_system is not None:
                    if self.recovery:
                        # Paste the region onto a same-sized canvas so the
                        # returned boxes are already in full-image coordinates.
                        # NOTE(review): the canvas is filled with ones (not
                        # 255) despite the "wht" name - confirm intent.
                        wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
                        wht_im[y1:y2, x1:x2, :] = roi_img
                        filter_boxes, filter_rec_res, ocr_time_dict = (
                            self.text_system(wht_im)
                        )
                    else:
                        filter_boxes, filter_rec_res, ocr_time_dict = (
                            self.text_system(roi_img)
                        )
                    time_dict["det"] += ocr_time_dict["det"]
                    time_dict["rec"] += ocr_time_dict["rec"]

                    res = []
                    for box, rec_res in zip(filter_boxes, filter_rec_res):
                        rec_str, rec_conf = rec_res
                        rec_str = self._strip_style_tokens(rec_str)
                        if not self.recovery:
                            # Boxes are relative to the crop; shift them back
                            # by the crop origin.
                            box += [x1, y1]
                        res.append(
                            {
                                "text": rec_str,
                                "confidence": float(rec_conf),
                                "text_region": box.tolist(),
                            }
                        )

                res_list.append(
                    {
                        "type": region["label"].lower(),
                        "bbox": [x1, y1, x2, y2],
                        "img": roi_img,
                        "res": res,
                        "img_idx": img_idx,
                    }
                )

            end = time.time()
            time_dict["all"] = end - start
            return res_list, time_dict

        if self.mode == "kie":
            re_res, elapse = self.kie_predictor(img)
            time_dict["kie"] = elapse
            time_dict["all"] = elapse
            return re_res[0], time_dict

        return None, None
class CustomTableSystem(TableSystem):
    """``TableSystem`` variant that can recognise text inside predicted cells.

    With ``prefer_table_cell_boxes=True`` the cell boxes produced by the table
    structure model are used as the text boxes and text detection is skipped.
    """

    def __call__(
        self,
        img,
        return_ocr_result_in_table=False,
        prefer_table_cell_boxes=False,
    ):
        """Predict table structure, recognise text and build the HTML.

        Args:
            img: Table image array.
            return_ocr_result_in_table: When true, the text boxes and the
                recognition results are added to the result under ``"boxes"``
                and ``"rec_res"``.
            prefer_table_cell_boxes: When true, the structure model's cell
                boxes are passed to ``_ocr`` instead of running detection.

        Returns:
            ``(result, time_dict)``. ``result`` holds ``"cell_bbox"`` and
            ``"html"`` (plus the optional keys above); ``time_dict`` holds the
            ``det``, ``rec``, ``table``, ``match`` and ``all`` timings.
        """
        result = dict()
        time_dict = {"det": 0, "rec": 0, "table": 0, "all": 0, "match": 0}
        start = time.time()

        structure_res, elapse = self._structure(copy.deepcopy(img))
        result["cell_bbox"] = structure_res[1].tolist()
        time_dict["table"] = elapse

        # ---------- Custom adjust start ------------ #
        if prefer_table_cell_boxes:
            # Each cell box is read as 8 values and regrouped into four
            # (x, y) points; coordinates are truncated to integers and
            # stored as floats.
            # NOTE(review): assumes every cell box has at least 8 values -
            # confirm for structure models that emit 4-value boxes.
            cell_boxes = (
                np.array(
                    [
                        [[b[0], b[1]], [b[2], b[3]], [b[4], b[5]], [b[6], b[7]]]
                        for b in structure_res[1]
                    ]
                )
                .astype(int)
                .astype(float)
            )
        else:
            cell_boxes = None
        dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
            copy.deepcopy(img), cell_boxes=cell_boxes
        )
        # ---------- Custom adjust end ------------- #
        time_dict["det"] = det_elapse
        time_dict["rec"] = rec_elapse

        if return_ocr_result_in_table:
            result["boxes"] = [x.tolist() for x in dt_boxes]
            result["rec_res"] = rec_res

        tic = time.time()
        pred_html = self.match(structure_res, dt_boxes, rec_res)
        toc = time.time()
        time_dict["match"] = toc - tic
        result["html"] = pred_html

        end = time.time()
        time_dict["all"] = end - start
        return result, time_dict

    def _ocr(self, img, cell_boxes=None):
        """Detect (or accept) text boxes and recognise the text inside them.

        Args:
            img: Table image array.
            cell_boxes: Optional array of 4-point boxes used in place of the
                text detector's output. ``None`` runs the detector.

        Returns:
            ``(dt_boxes, rec_res, det_elapse, rec_elapse)`` where ``dt_boxes``
            is an array of ``[x_min, y_min, x_max, y_max]`` rows. When the
            detector yields ``None`` the boxes and results are empty and
            ``rec_elapse`` is ``0.0``.
        """
        h, w = img.shape[:2]
        # ---------- Custom adjust start ------------ #
        if cell_boxes is None:
            dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
        else:
            logger.debug("use cell_boxes instead of dt_boxes")
            dt_boxes, det_elapse = cell_boxes, 0.0
        # ---------- Custom adjust end ------------- #

        if dt_boxes is None:
            # The None guard must run before sorting (it used to sit after
            # np.array(), where it could never fire) and must keep the
            # 4-tuple shape that __call__ unpacks.
            logger.debug("dt_boxes num : %s, elapse : %s", 0, det_elapse)
            return np.array([]), [], det_elapse, 0.0

        dt_boxes = sorted_boxes(dt_boxes)
        # Collapse every polygon to an axis-aligned rectangle grown by one
        # pixel and clamped to the image bounds.
        dt_boxes = np.array(
            [
                [
                    max(0, box[:, 0].min() - 1),
                    max(0, box[:, 1].min() - 1),
                    min(w, box[:, 0].max() + 1),
                    min(h, box[:, 1].max() + 1),
                ]
                for box in dt_boxes
            ]
        )
        logger.debug(
            "dt_boxes num : %s, elapse : %s", len(dt_boxes), det_elapse
        )

        img_crop_list = []
        for det_box in dt_boxes:
            # expand() comes from ppstructure; presumably it pads the box by
            # the given margin within the image shape - verify upstream.
            x0, y0, x1, y1 = expand(2, det_box, img.shape)
            img_crop_list.append(img[int(y0) : int(y1), int(x0) : int(x1), :])

        rec_res, rec_elapse = self.text_recognizer(img_crop_list)
        logger.debug(
            "rec_res num : %s, elapse : %s", len(rec_res), rec_elapse
        )
        return dt_boxes, rec_res, det_elapse, rec_elapse
|