table_engine.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. import copy
  2. import time
  3. import os
  4. from pathlib import Path
  5. import logging
  6. import numpy as np
  7. import paddleocr
  8. from paddleocr.paddleocr import (
  9. PPStructure,
  10. parse_args,
  11. check_gpu,
  12. parse_lang,
  13. get_model_config,
  14. confirm_model_dir_url,
  15. maybe_download,
  16. check_img,
  17. BASE_DIR,
  18. SUPPORT_STRUCTURE_MODEL_VERSION,
  19. )
  20. from ppstructure.table.predict_table import (
  21. TableSystem,
  22. expand,
  23. sorted_boxes,
  24. logger,
  25. )
  26. from ppstructure.predict_system import LayoutPredictor, TextSystem
  27. class CustomPPStructure(PPStructure):
  28. def __init__(self, **kwargs):
  29. params = parse_args(mMain=False)
  30. params.__dict__.update(**kwargs)
  31. assert (
  32. params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION
  33. ), "structure_version must in {}, but get {}".format(
  34. SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version
  35. )
  36. params.use_gpu = check_gpu(params.use_gpu)
  37. params.mode = "structure"
  38. if not params.show_log:
  39. logger.setLevel(logging.INFO)
  40. lang, det_lang = parse_lang(params.lang)
  41. if lang == "ch":
  42. table_lang = "ch"
  43. else:
  44. table_lang = "en"
  45. if params.structure_version == "PP-Structure":
  46. params.merge_no_span_structure = False
  47. # init model dir
  48. det_model_config = get_model_config(
  49. "OCR", params.ocr_version, "det", det_lang
  50. )
  51. params.det_model_dir, det_url = confirm_model_dir_url(
  52. params.det_model_dir,
  53. os.path.join(BASE_DIR, "whl", "det", det_lang),
  54. det_model_config["url"],
  55. )
  56. rec_model_config = get_model_config(
  57. "OCR", params.ocr_version, "rec", lang
  58. )
  59. params.rec_model_dir, rec_url = confirm_model_dir_url(
  60. params.rec_model_dir,
  61. os.path.join(BASE_DIR, "whl", "rec", lang),
  62. rec_model_config["url"],
  63. )
  64. table_model_config = get_model_config(
  65. "STRUCTURE", params.structure_version, "table", table_lang
  66. )
  67. params.table_model_dir, table_url = confirm_model_dir_url(
  68. params.table_model_dir,
  69. os.path.join(BASE_DIR, "whl", "table"),
  70. table_model_config["url"],
  71. )
  72. layout_model_config = get_model_config(
  73. "STRUCTURE", params.structure_version, "layout", lang
  74. )
  75. params.layout_model_dir, layout_url = confirm_model_dir_url(
  76. params.layout_model_dir,
  77. os.path.join(BASE_DIR, "whl", "layout"),
  78. layout_model_config["url"],
  79. )
  80. # download model
  81. maybe_download(params.det_model_dir, det_url)
  82. maybe_download(params.rec_model_dir, rec_url)
  83. maybe_download(params.table_model_dir, table_url)
  84. maybe_download(params.layout_model_dir, layout_url)
  85. # ---------- Custom adjust start ------------ #
  86. paddleocr_path = Path(paddleocr.__file__).parent
  87. if params.rec_char_dict_path is None:
  88. params.rec_char_dict_path = str(
  89. paddleocr_path / rec_model_config["dict_path"]
  90. )
  91. if params.table_char_dict_path is None:
  92. params.table_char_dict_path = str(
  93. paddleocr_path / table_model_config["dict_path"]
  94. )
  95. if params.layout_dict_path is None:
  96. params.layout_dict_path = str(
  97. paddleocr_path / layout_model_config["dict_path"]
  98. )
  99. # ---------- Custom adjust end ------------- #
  100. logger.debug(params)
  101. # Code from ppstructure.predict_system.StructureSystem
  102. args = params
  103. self.mode = args.mode
  104. self.recovery = args.recovery
  105. self.image_orientation_predictor = None
  106. if args.image_orientation:
  107. import paddleclas
  108. self.image_orientation_predictor = paddleclas.PaddleClas(
  109. model_name="text_image_orientation"
  110. )
  111. if self.mode == "structure":
  112. if not args.show_log:
  113. logger.setLevel(logging.INFO)
  114. if args.layout == False and args.ocr == True:
  115. args.ocr = False
  116. logger.warning(
  117. "When args.layout is false, args.ocr is automatically set to false"
  118. )
  119. args.drop_score = 0
  120. # init model
  121. self.layout_predictor = None
  122. self.text_system = None
  123. self.table_system = None
  124. if args.layout:
  125. self.layout_predictor = LayoutPredictor(args)
  126. if args.ocr:
  127. self.text_system = TextSystem(args)
  128. # ---------- Custom adjust start ------------ #
  129. if args.table:
  130. if self.text_system is not None:
  131. self.table_system = CustomTableSystem(
  132. args,
  133. self.text_system.text_detector,
  134. self.text_system.text_recognizer,
  135. )
  136. else:
  137. self.table_system = CustomTableSystem(args)
  138. # ---------- Custom adjust end ------------- #
  139. elif self.mode == "kie":
  140. from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
  141. self.kie_predictor = SerRePredictor(args)
  142. def __call__(
  143. self,
  144. img,
  145. return_ocr_result_in_table=False,
  146. img_idx=0,
  147. prefer_table_cell_boxes=False,
  148. ):
  149. img = check_img(img)
  150. res, _ = self._super_call(
  151. img,
  152. return_ocr_result_in_table,
  153. img_idx=img_idx,
  154. prefer_table_cell_boxes=prefer_table_cell_boxes,
  155. )
  156. return res
  157. def _super_call(
  158. self,
  159. img,
  160. return_ocr_result_in_table=False,
  161. img_idx=0,
  162. prefer_table_cell_boxes=False,
  163. ):
  164. time_dict = {
  165. "image_orientation": 0,
  166. "layout": 0,
  167. "table": 0,
  168. "table_match": 0,
  169. "det": 0,
  170. "rec": 0,
  171. "kie": 0,
  172. "all": 0,
  173. }
  174. start = time.time()
  175. if self.image_orientation_predictor is not None:
  176. tic = time.time()
  177. cls_result = self.image_orientation_predictor.predict(
  178. input_data=img
  179. )
  180. cls_res = next(cls_result)
  181. angle = cls_res[0]["label_names"][0]
  182. cv_rotate_code = {
  183. "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
  184. "180": cv2.ROTATE_180,
  185. "270": cv2.ROTATE_90_CLOCKWISE,
  186. }
  187. if angle in cv_rotate_code:
  188. img = cv2.rotate(img, cv_rotate_code[angle])
  189. toc = time.time()
  190. time_dict["image_orientation"] = toc - tic
  191. if self.mode == "structure":
  192. ori_im = img.copy()
  193. if self.layout_predictor is not None:
  194. layout_res, elapse = self.layout_predictor(img)
  195. time_dict["layout"] += elapse
  196. else:
  197. h, w = ori_im.shape[:2]
  198. layout_res = [dict(bbox=None, label="table")]
  199. res_list = []
  200. for region in layout_res:
  201. res = ""
  202. if region["bbox"] is not None:
  203. x1, y1, x2, y2 = region["bbox"]
  204. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  205. roi_img = ori_im[y1:y2, x1:x2, :]
  206. else:
  207. x1, y1, x2, y2 = 0, 0, w, h
  208. roi_img = ori_im
  209. if region["label"] == "table":
  210. if self.table_system is not None:
  211. # ---------- Custom adjust start ------------ #
  212. res, table_time_dict = self.table_system(
  213. roi_img,
  214. return_ocr_result_in_table,
  215. prefer_table_cell_boxes,
  216. )
  217. # ---------- Custom adjust end ------------- #
  218. time_dict["table"] += table_time_dict["table"]
  219. time_dict["table_match"] += table_time_dict["match"]
  220. time_dict["det"] += table_time_dict["det"]
  221. time_dict["rec"] += table_time_dict["rec"]
  222. else:
  223. if self.text_system is not None:
  224. if self.recovery:
  225. wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
  226. wht_im[y1:y2, x1:x2, :] = roi_img
  227. filter_boxes, filter_rec_res, ocr_time_dict = (
  228. self.text_system(wht_im)
  229. )
  230. else:
  231. filter_boxes, filter_rec_res, ocr_time_dict = (
  232. self.text_system(roi_img)
  233. )
  234. time_dict["det"] += ocr_time_dict["det"]
  235. time_dict["rec"] += ocr_time_dict["rec"]
  236. # remove style char,
  237. # when using the recognition model trained on the PubtabNet dataset,
  238. # it will recognize the text format in the table, such as <b>
  239. style_token = [
  240. "<strike>",
  241. "<strike>",
  242. "<sup>",
  243. "</sub>",
  244. "<b>",
  245. "</b>",
  246. "<sub>",
  247. "</sup>",
  248. "<overline>",
  249. "</overline>",
  250. "<underline>",
  251. "</underline>",
  252. "<i>",
  253. "</i>",
  254. ]
  255. res = []
  256. for box, rec_res in zip(filter_boxes, filter_rec_res):
  257. rec_str, rec_conf = rec_res
  258. for token in style_token:
  259. if token in rec_str:
  260. rec_str = rec_str.replace(token, "")
  261. if not self.recovery:
  262. box += [x1, y1]
  263. res.append(
  264. {
  265. "text": rec_str,
  266. "confidence": float(rec_conf),
  267. "text_region": box.tolist(),
  268. }
  269. )
  270. res_list.append(
  271. {
  272. "type": region["label"].lower(),
  273. "bbox": [x1, y1, x2, y2],
  274. "img": roi_img,
  275. "res": res,
  276. "img_idx": img_idx,
  277. }
  278. )
  279. end = time.time()
  280. time_dict["all"] = end - start
  281. return res_list, time_dict
  282. elif self.mode == "kie":
  283. re_res, elapse = self.kie_predictor(img)
  284. time_dict["kie"] = elapse
  285. time_dict["all"] = elapse
  286. return re_res[0], time_dict
  287. return None, None
  288. class CustomTableSystem(TableSystem):
  289. def __call__(
  290. self,
  291. img,
  292. return_ocr_result_in_table=False,
  293. prefer_table_cell_boxes=False,
  294. ):
  295. result = dict()
  296. time_dict = {"det": 0, "rec": 0, "table": 0, "all": 0, "match": 0}
  297. start = time.time()
  298. structure_res, elapse = self._structure(copy.deepcopy(img))
  299. result["cell_bbox"] = structure_res[1].tolist()
  300. time_dict["table"] = elapse
  301. # ---------- Custom adjust start ------------ #
  302. if prefer_table_cell_boxes:
  303. cell_boxes = (
  304. np.array([[[b[0], b[1]], [b[2], b[3]], [b[4], b[5]], [b[6], b[7]]] for b in structure_res[1]])
  305. .astype(int)
  306. .astype(float)
  307. )
  308. else:
  309. cell_boxes = None
  310. dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
  311. copy.deepcopy(img), cell_boxes=cell_boxes
  312. )
  313. # ---------- Custom adjust end ------------- #
  314. time_dict["det"] = det_elapse
  315. time_dict["rec"] = rec_elapse
  316. if return_ocr_result_in_table:
  317. result["boxes"] = [x.tolist() for x in dt_boxes]
  318. result["rec_res"] = rec_res
  319. tic = time.time()
  320. pred_html = self.match(structure_res, dt_boxes, rec_res)
  321. toc = time.time()
  322. time_dict["match"] = toc - tic
  323. result["html"] = pred_html
  324. end = time.time()
  325. time_dict["all"] = end - start
  326. return result, time_dict
  327. def _ocr(self, img, cell_boxes=None):
  328. h, w = img.shape[:2]
  329. # ---------- Custom adjust start ------------ #
  330. if cell_boxes is None:
  331. dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
  332. else:
  333. logger.debug("use cell_boxes instead of dt_boxes")
  334. dt_boxes, det_elapse = cell_boxes, 0.0
  335. # ---------- Custom adjust end ------------- #
  336. dt_boxes = sorted_boxes(dt_boxes)
  337. r_boxes = []
  338. for box in dt_boxes:
  339. x_min = max(0, box[:, 0].min() - 1)
  340. x_max = min(w, box[:, 0].max() + 1)
  341. y_min = max(0, box[:, 1].min() - 1)
  342. y_max = min(h, box[:, 1].max() + 1)
  343. box = [x_min, y_min, x_max, y_max]
  344. r_boxes.append(box)
  345. dt_boxes = np.array(r_boxes)
  346. logger.debug(
  347. "dt_boxes num : {}, elapse : {}".format(len(dt_boxes), det_elapse)
  348. )
  349. if dt_boxes is None:
  350. return None, None
  351. img_crop_list = []
  352. for i in range(len(dt_boxes)):
  353. det_box = dt_boxes[i]
  354. x0, y0, x1, y1 = expand(2, det_box, img.shape)
  355. text_rect = img[int(y0) : int(y1), int(x0) : int(x1), :]
  356. img_crop_list.append(text_rect)
  357. rec_res, rec_elapse = self.text_recognizer(img_crop_list)
  358. logger.debug(
  359. "rec_res num : {}, elapse : {}".format(len(rec_res), rec_elapse)
  360. )
  361. return dt_boxes, rec_res, det_elapse, rec_elapse