"""Customized PP-Structure pipeline that can run OCR on predicted table cells.

``CustomPPStructure`` re-implements the constructor and call path of
``paddleocr.PPStructure`` so that a ``CustomTableSystem`` is used for table
regions.  ``CustomTableSystem`` can optionally feed the cell boxes predicted by
the table-structure model to the recognizer instead of running text detection.
"""

import copy
import time
import os
from pathlib import Path
import logging

import numpy as np

import paddleocr
from paddleocr.paddleocr import (
    PPStructure,
    parse_args,
    check_gpu,
    parse_lang,
    get_model_config,
    confirm_model_dir_url,
    maybe_download,
    check_img,
    BASE_DIR,
    SUPPORT_STRUCTURE_MODEL_VERSION,
)
from ppstructure.table.predict_table import (
    TableSystem,
    expand,
    sorted_boxes,
    logger,
)
from ppstructure.predict_system import LayoutPredictor, TextSystem

# Style markup that a recognizer trained on PubTabNet may emit inside table
# text (e.g. "<b>").  These tokens are stripped from recognized strings.
# NOTE(review): the literals in this file had been reduced to empty strings
# (most likely stripped as HTML), which made the clean-up a no-op.  The list
# is restored from upstream PaddleOCR, with "</strike>" added as the evident
# counterpart of "<strike>" -- confirm against the PaddleOCR version in use.
_STYLE_TOKENS = (
    "<strike>",
    "</strike>",
    "<sup>",
    "</sub>",
    "<b>",
    "</b>",
    "<sub>",
    "</sup>",
    "<overline>",
    "</overline>",
    "<underline>",
    "</underline>",
    "<i>",
    "</i>",
)


class CustomPPStructure(PPStructure):
    """PP-Structure variant whose table branch uses ``CustomTableSystem``."""

    def __init__(self, **kwargs):
        """Resolve model locations, download missing models, build predictors.

        Args:
            **kwargs: Overrides written directly onto the namespace returned
                by ``parse_args(mMain=False)``.

        Raises:
            AssertionError: If ``structure_version`` is not listed in
                ``SUPPORT_STRUCTURE_MODEL_VERSION``.

        Note:
            The parent ``__init__`` is intentionally not called; its work is
            reproduced here with the adjustments marked "Custom adjust".
        """
        params = parse_args(mMain=False)
        params.__dict__.update(**kwargs)
        assert (
            params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION
        ), "structure_version must in {}, but get {}".format(
            SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version
        )
        params.use_gpu = check_gpu(params.use_gpu)
        params.mode = "structure"

        if not params.show_log:
            logger.setLevel(logging.INFO)
        lang, det_lang = parse_lang(params.lang)
        # Table models are only looked up for "ch" and "en".
        table_lang = "ch" if lang == "ch" else "en"
        if params.structure_version == "PP-Structure":
            params.merge_no_span_structure = False

        # init model dir
        det_model_config = get_model_config(
            "OCR", params.ocr_version, "det", det_lang
        )
        params.det_model_dir, det_url = confirm_model_dir_url(
            params.det_model_dir,
            os.path.join(BASE_DIR, "whl", "det", det_lang),
            det_model_config["url"],
        )
        rec_model_config = get_model_config(
            "OCR", params.ocr_version, "rec", lang
        )
        params.rec_model_dir, rec_url = confirm_model_dir_url(
            params.rec_model_dir,
            os.path.join(BASE_DIR, "whl", "rec", lang),
            rec_model_config["url"],
        )
        table_model_config = get_model_config(
            "STRUCTURE", params.structure_version, "table", table_lang
        )
        params.table_model_dir, table_url = confirm_model_dir_url(
            params.table_model_dir,
            os.path.join(BASE_DIR, "whl", "table"),
            table_model_config["url"],
        )
        layout_model_config = get_model_config(
            "STRUCTURE", params.structure_version, "layout", lang
        )
        params.layout_model_dir, layout_url = confirm_model_dir_url(
            params.layout_model_dir,
            os.path.join(BASE_DIR, "whl", "layout"),
            layout_model_config["url"],
        )

        # download model
        maybe_download(params.det_model_dir, det_url)
        maybe_download(params.rec_model_dir, rec_url)
        maybe_download(params.table_model_dir, table_url)
        maybe_download(params.layout_model_dir, layout_url)

        # ---------- Custom adjust start ------------ #
        # Dictionary paths are resolved relative to the installed paddleocr
        # package directory rather than the current working directory.
        paddleocr_path = Path(paddleocr.__file__).parent
        if params.rec_char_dict_path is None:
            params.rec_char_dict_path = str(
                paddleocr_path / rec_model_config["dict_path"]
            )
        if params.table_char_dict_path is None:
            params.table_char_dict_path = str(
                paddleocr_path / table_model_config["dict_path"]
            )
        if params.layout_dict_path is None:
            params.layout_dict_path = str(
                paddleocr_path / layout_model_config["dict_path"]
            )
        # ---------- Custom adjust end ------------- #
        logger.debug(params)

        # Code from ppstructure.predict_system.StructureSystem
        args = params
        self.mode = args.mode
        self.recovery = args.recovery

        self.image_orientation_predictor = None
        if args.image_orientation:
            # Imported lazily so paddleclas is only needed when requested.
            import paddleclas

            self.image_orientation_predictor = paddleclas.PaddleClas(
                model_name="text_image_orientation"
            )

        if self.mode == "structure":
            if not args.show_log:
                logger.setLevel(logging.INFO)
            if not args.layout and args.ocr:
                args.ocr = False
                logger.warning(
                    "When args.layout is false, args.ocr is automatically set to false"
                )
            args.drop_score = 0
            # init model
            self.layout_predictor = None
            self.text_system = None
            self.table_system = None
            if args.layout:
                self.layout_predictor = LayoutPredictor(args)
                if args.ocr:
                    self.text_system = TextSystem(args)
            # ---------- Custom adjust start ------------ #
            if args.table:
                if self.text_system is not None:
                    # Share the already-built detector and recognizer.
                    self.table_system = CustomTableSystem(
                        args,
                        self.text_system.text_detector,
                        self.text_system.text_recognizer,
                    )
                else:
                    self.table_system = CustomTableSystem(args)
            # ---------- Custom adjust end ------------- #
        elif self.mode == "kie":
            from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor

            self.kie_predictor = SerRePredictor(args)

    def __call__(
        self,
        img,
        return_ocr_result_in_table=False,
        img_idx=0,
        prefer_table_cell_boxes=False,
    ):
        """Analyze ``img`` and return only the result list.

        Args:
            img: Image input; it is passed through ``check_img`` first.
            return_ocr_result_in_table: Forwarded to the table system; when
                true, table results also carry ``boxes`` and ``rec_res``.
            img_idx: Value stored under ``"img_idx"`` in every region result.
            prefer_table_cell_boxes: When true, table OCR uses the predicted
                cell boxes instead of running text detection.

        Returns:
            The first element returned by ``_super_call`` (timings dropped).
        """
        img = check_img(img)
        res, _ = self._super_call(
            img,
            return_ocr_result_in_table,
            img_idx=img_idx,
            prefer_table_cell_boxes=prefer_table_cell_boxes,
        )
        return res

    def _super_call(
        self,
        img,
        return_ocr_result_in_table=False,
        img_idx=0,
        prefer_table_cell_boxes=False,
    ):
        """Run the pipeline and return ``(result, time_dict)``.

        In "structure" mode the result is a list with one dict per layout
        region (keys ``type``, ``bbox``, ``img``, ``res``, ``img_idx``).  In
        "kie" mode it is the first element of the KIE predictor output.  For
        any other mode ``(None, None)`` is returned.
        """
        time_dict = {
            "image_orientation": 0,
            "layout": 0,
            "table": 0,
            "table_match": 0,
            "det": 0,
            "rec": 0,
            "kie": 0,
            "all": 0,
        }
        start = time.time()

        if self.image_orientation_predictor is not None:
            # cv2 was referenced here without ever being imported, which
            # raised NameError as soon as orientation correction was enabled.
            import cv2

            tic = time.time()
            cls_result = self.image_orientation_predictor.predict(
                input_data=img
            )
            cls_res = next(cls_result)
            angle = cls_res[0]["label_names"][0]
            cv_rotate_code = {
                "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
                "180": cv2.ROTATE_180,
                "270": cv2.ROTATE_90_CLOCKWISE,
            }
            if angle in cv_rotate_code:
                img = cv2.rotate(img, cv_rotate_code[angle])
            toc = time.time()
            time_dict["image_orientation"] = toc - tic

        if self.mode == "structure":
            ori_im = img.copy()
            # Bound unconditionally: the bbox=None fallback below reads h and
            # w even when a layout predictor produced the regions.
            h, w = ori_im.shape[:2]
            if self.layout_predictor is not None:
                layout_res, elapse = self.layout_predictor(img)
                time_dict["layout"] += elapse
            else:
                # Without layout analysis the whole image is one table.
                layout_res = [dict(bbox=None, label="table")]

            res_list = []
            for region in layout_res:
                res = ""
                if region["bbox"] is not None:
                    x1, y1, x2, y2 = (int(v) for v in region["bbox"])
                    roi_img = ori_im[y1:y2, x1:x2, :]
                else:
                    x1, y1, x2, y2 = 0, 0, w, h
                    roi_img = ori_im

                if region["label"] == "table":
                    if self.table_system is not None:
                        # ---------- Custom adjust start ------------ #
                        res, table_time_dict = self.table_system(
                            roi_img,
                            return_ocr_result_in_table,
                            prefer_table_cell_boxes,
                        )
                        # ---------- Custom adjust end ------------- #
                        time_dict["table"] += table_time_dict["table"]
                        time_dict["table_match"] += table_time_dict["match"]
                        time_dict["det"] += table_time_dict["det"]
                        time_dict["rec"] += table_time_dict["rec"]
                elif self.text_system is not None:
                    if self.recovery:
                        # Paste the region onto a blank canvas of the full
                        # image size so box coordinates stay page-relative.
                        wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
                        wht_im[y1:y2, x1:x2, :] = roi_img
                        filter_boxes, filter_rec_res, ocr_time_dict = (
                            self.text_system(wht_im)
                        )
                    else:
                        filter_boxes, filter_rec_res, ocr_time_dict = (
                            self.text_system(roi_img)
                        )
                    time_dict["det"] += ocr_time_dict["det"]
                    time_dict["rec"] += ocr_time_dict["rec"]

                    res = []
                    for box, rec_res in zip(filter_boxes, filter_rec_res):
                        rec_str, rec_conf = rec_res
                        # remove style char (see _STYLE_TOKENS)
                        for token in _STYLE_TOKENS:
                            rec_str = rec_str.replace(token, "")
                        if not self.recovery:
                            # Shift region-relative boxes to page coordinates.
                            box += [x1, y1]
                        res.append(
                            {
                                "text": rec_str,
                                "confidence": float(rec_conf),
                                "text_region": box.tolist(),
                            }
                        )

                res_list.append(
                    {
                        "type": region["label"].lower(),
                        "bbox": [x1, y1, x2, y2],
                        "img": roi_img,
                        "res": res,
                        "img_idx": img_idx,
                    }
                )
            end = time.time()
            time_dict["all"] = end - start
            return res_list, time_dict

        if self.mode == "kie":
            re_res, elapse = self.kie_predictor(img)
            time_dict["kie"] = elapse
            time_dict["all"] = elapse
            return re_res[0], time_dict

        return None, None


class CustomTableSystem(TableSystem):
    """Table system that can recognize text inside predicted cell boxes."""

    def __call__(
        self,
        img,
        return_ocr_result_in_table=False,
        prefer_table_cell_boxes=False,
    ):
        """Predict table structure, OCR its text and build the HTML.

        Args:
            img: Table image; ``img.shape[:2]`` is read as (height, width).
            return_ocr_result_in_table: When true, add ``boxes`` and
                ``rec_res`` to the result.
            prefer_table_cell_boxes: When true, skip text detection and use
                the cell boxes from the structure prediction instead.

        Returns:
            ``(result, time_dict)``.  ``result`` has ``cell_bbox`` and
            ``html`` (plus the optional OCR keys); ``time_dict`` has the keys
            ``det``, ``rec``, ``table``, ``match`` and ``all``.
        """
        result = dict()
        time_dict = {"det": 0, "rec": 0, "table": 0, "all": 0, "match": 0}
        start = time.time()

        structure_res, elapse = self._structure(copy.deepcopy(img))
        result["cell_bbox"] = structure_res[1].tolist()
        time_dict["table"] = elapse

        # ---------- Custom adjust start ------------ #
        if prefer_table_cell_boxes:
            cell_boxes = self._cells_to_quads(structure_res[1])
        else:
            cell_boxes = None
        dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
            copy.deepcopy(img), cell_boxes=cell_boxes
        )
        # ---------- Custom adjust end ------------- #
        time_dict["det"] = det_elapse
        time_dict["rec"] = rec_elapse

        if return_ocr_result_in_table:
            result["boxes"] = [x.tolist() for x in dt_boxes]
            result["rec_res"] = rec_res

        tic = time.time()
        pred_html = self.match(structure_res, dt_boxes, rec_res)
        toc = time.time()
        time_dict["match"] = toc - tic
        result["html"] = pred_html
        end = time.time()
        time_dict["all"] = end - start
        return result, time_dict

    @staticmethod
    def _cells_to_quads(cell_bboxes):
        """Convert predicted cell boxes to an ``(N, 4, 2)`` float array.

        Eight-value boxes are read as four ``(x, y)`` corner points.
        Four-value boxes are read as ``[x1, y1, x2, y2]`` and expanded to
        their four corners; previously these raised ``IndexError``.
        Coordinates are truncated to integers and stored as floats.
        """
        quads = []
        for b in cell_bboxes:
            if len(b) == 4:
                x1, y1, x2, y2 = b
                quads.append([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
            else:
                quads.append(
                    [[b[0], b[1]], [b[2], b[3]], [b[4], b[5]], [b[6], b[7]]]
                )
        if not quads:
            return np.zeros((0, 4, 2), dtype=float)
        return np.array(quads).astype(int).astype(float)

    def _ocr(self, img, cell_boxes=None):
        """Detect (or accept) text boxes and recognize the text inside them.

        Args:
            img: Image array indexed as ``img[y, x, channel]``.
            cell_boxes: Optional boxes indexable as ``box[:, 0]`` (x) and
                ``box[:, 1]`` (y).  When given, text detection is skipped
                and the detection time is reported as ``0.0``.

        Returns:
            ``(dt_boxes, rec_res, det_elapse, rec_elapse)`` where
            ``dt_boxes`` is an array of ``[x_min, y_min, x_max, y_max]``
            rows, each padded by one pixel and clipped to the image.
        """
        h, w = img.shape[:2]
        # ---------- Custom adjust start ------------ #
        if cell_boxes is None:
            dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
        else:
            logger.debug("use cell_boxes instead of dt_boxes")
            dt_boxes, det_elapse = cell_boxes, 0.0
        # ---------- Custom adjust end ------------- #

        # The former None check ran after np.array(...), so it could never
        # fire, and it returned two values where the caller unpacks four.
        # NOTE(review): presumably the detector can yield None when nothing
        # is found; that case is treated as "no boxes" -- confirm.
        if dt_boxes is None:
            dt_boxes = np.zeros((0, 4, 2), dtype=np.float32)

        dt_boxes = sorted_boxes(dt_boxes)

        # Collapse each quadrilateral to an axis-aligned rectangle, padded by
        # one pixel and clipped to the image bounds.
        r_boxes = []
        for box in dt_boxes:
            x_min = max(0, box[:, 0].min() - 1)
            x_max = min(w, box[:, 0].max() + 1)
            y_min = max(0, box[:, 1].min() - 1)
            y_max = min(h, box[:, 1].max() + 1)
            r_boxes.append([x_min, y_min, x_max, y_max])
        dt_boxes = np.array(r_boxes)
        logger.debug(
            "dt_boxes num : {}, elapse : {}".format(len(dt_boxes), det_elapse)
        )

        img_crop_list = []
        for det_box in dt_boxes:
            x0, y0, x1, y1 = expand(2, det_box, img.shape)
            img_crop_list.append(img[int(y0) : int(y1), int(x0) : int(x1), :])

        rec_res, rec_elapse = self.text_recognizer(img_crop_list)
        logger.debug(
            "rec_res num : {}, elapse : {}".format(len(rec_res), rec_elapse)
        )
        return dt_boxes, rec_res, det_elapse, rec_elapse