# -*- coding: UTF-8 -*-
"""HTTP service that recognises tables in images and returns them as HTML.

On import this module builds the FastAPI app, a PP-Structure table engine and
a PaddleClas ``text_image_orientation`` classifier. Each model is guarded by
its own lock so that only one request at a time runs inference on it.
"""
# NOTE: import order is kept exactly as before; several of these are star
# imports, so reordering could change which module's names win.
import json
from base64 import b64decode
import base64
import cv2
import numpy as np
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from paddleocr import PaddleOCR, PPStructure
from sx_utils.sxweb import *
from sx_utils.sximage import *
import threading
import os
import re
from sx_utils.sx_log import *
import paddleclas
from cores.post_hander import *
from cores.check_table import *

format_print()

# Initialise the app.
app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serialises access to the shared table engine.
table_engine_lock = threading.Lock()
table_engine = PPStructure(layout=False,
                           table=True,
                           use_gpu=True,
                           show_log=True,
                           use_angle_cls=True,
                           # det_model_dir="models/det/det_table_v2",
                           # det_model_dir="models/det/det_table_v3",
                           # rec_model_dir="models/rec/rec_table_v1",
                           table_model_dir="models/table/SLANet_911")

# Serialises access to the shared orientation classifier.
cls_lock = threading.Lock()
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")

# # Regular tables
# table_engine = PPStructure(layout=False,
#                            table=True,
#                            use_gpu=use_gpu,
#                            show_log=True,
#                            det_model_dir="models/det/det_table_v2",
#                            rec_model_dir="./models/rec/rec_table_v1",
#                            table_model_dir="models/table/SLANet_v2")
#
# # Longer tables
# table_engine1 = PPStructure(layout=False,
#                             table=True,
#                             use_gpu=use_gpu,
#                             show_log=True,
#                             det_model_dir="models/det/det_table_v1",
#                             rec_model_dir="./models/rec/rec_table_v1",
#                             table_model_dir="./models/table/SLAnet_v1")
#
# # Supplementary model for certain special cases
# table_engine2 = PPStructure(layout=False,
#                             table=True,
#                             use_gpu=use_gpu,
#                             show_log=True,
#                             det_model_dir="models/det/det_table_v3",
#                             rec_model_dir="./models/rec/rec_table_v1",
#                             table_model_dir="./models/table/SLAnet_v1")
#
#
#


def cal_html_to_chs(html):
    """Return the number of characters left in ``html`` after clean-up.

    Used to judge how well a table was recognised at a given rotation angle:
    the more text that was recognised, the better the result.

    The cleaned text is also printed to stdout.

    NOTE(review): every tag/pattern literal below is an empty string in this
    copy of the file (one of them contains only a line break). They look like
    HTML tags that were stripped when the file was exported -- restore the
    intended literals from version control before relying on this function.
    """
    res = []
    rows = re.split('', html)
    for row in rows:
        row = re.split('', row)
        cells = list(map(lambda x: x.replace('', '').replace('', ''), row))
        rec_str = ''.join(cells)
        for tag in ['', '', '', '', '', '\n', '', '']:
            rec_str = rec_str.replace(tag, '')
        res.append(rec_str)
    rec_res = ''.join(res).replace(' ', '')
    rec_res = re.split('', rec_res)
    rec_res = ''.join(rec_res).replace(' ', '')
    print(rec_res)
    return len(rec_res)


def predict_cls(image, conf=0.8):
    """Classify the text orientation of ``image``.

    Args:
        image: image passed straight to ``cls_model.predict``.
        conf: minimum score the top prediction must exceed to be accepted.

    Returns:
        The top label converted to ``int`` for the first prediction whose
        score exceeds ``conf``, or ``-1`` when none does.
    """
    with cls_lock:
        # PaddleClas.predict() yields its results lazily, so the model only
        # runs while the generator is being consumed. Drain it while the lock
        # is still held; otherwise inference would run unprotected.
        result = list(cls_model.predict(image))

    for res in result:
        score = res[0]['scores'][0]
        label_name = res[0]['label_names'][0]
        print(f"score: {score}, label_name: {label_name}")
        # print(conf)
        if score > conf:
            return int(label_name)
    return -1


def rotate_to_zero(image, current_degree):
    """Rotate ``image`` so that an orientation of ``current_degree`` becomes 0.

    Args:
        image: image to rotate.
        current_degree: detected orientation in degrees; it is reduced to a
            number of quarter turns with ``// 90``.

    Returns:
        ``image`` itself when no rotation is needed, otherwise the rotated
        image.
    """
    # cv2.imwrite('1.jpg', image)
    current_degree = current_degree // 90
    if current_degree == 0:
        return image
    # Map quarter turns onto cv2.rotate codes:
    #   1 -> 2 (ROTATE_90_COUNTERCLOCKWISE)
    #   2 -> 1 (ROTATE_180)
    #   3 -> 0 (ROTATE_90_CLOCKWISE)
    to_rotate = (4 - current_degree) - 1
    image = cv2.rotate(image, to_rotate)
    # cv2.imwrite('2.jpg', image)
    return image


def get_zero_degree_image(img):
    """Return ``img`` rotated upright, or unchanged if no orientation is found.

    The classifier is tried on the image as given and then on each of the
    three ``cv2.rotate`` variants (codes 0, 1, 2). The confidence threshold
    drops by 0.2 per attempt (0.8, 0.6, 0.4, 0.2). The first attempt that
    yields an orientation is rotated back to zero degrees and returned.
    """
    step = 0.2
    for idx, i in enumerate([-1, 0, 1, 2]):
        # -1 means "try the image as-is"; 0..2 are cv2.rotate codes.
        if i >= 0:
            image = cv2.rotate(img.copy(), i)
        else:
            image = img.copy()
        conf = 0.8 - (idx * step)
        # 0 / 90 / 180 / 270, or -1 when the orientation is not recognised.
        current_degree = predict_cls(image, conf)
        if current_degree != -1:
            img = rotate_to_zero(image, current_degree)
            break
    return img


def table_res(im, ROTATE=-1):
    """Run table recognition on ``im`` after rotating it upright.

    Args:
        im: input image; it is copied, so the caller's array is not modified
            by this function.
        ROTATE: unused; kept so existing callers keep working.

    Returns:
        A ``(res, html)`` tuple: the raw engine output and the ``html`` string
        taken from its first entry.
    """
    im = im.copy()
    # cv2.imwrite('before-rotate.jpg', im)
    # Get the upright image.
    img = get_zero_degree_image(im)
    # cv2.imwrite('after-rotate.jpg', img)
    with table_engine_lock:
        res = table_engine(img)
    html = res[0]['res']['html']
    return res, html


class TableInfo(BaseModel):
    """Request body for the table endpoint."""

    # Passed to base64_to_np() by the endpoint.
    image: str
    # Not read anywhere in this module.
    det: str


@app.get("/ping")
def ping():
    """Liveness probe."""
    return 'pong!!!!!!!!!'


@app.post("/ocr_system/table")
@web_try()
def table(image: TableInfo):
    """Recognise the table in the posted image and return formatted HTML.

    Raises:
        Exception: when recognition produced no HTML.
    """
    img = base64_to_np(image.image)
    _, html = table_res(img)
    # print(html)
    # Named ``checker`` rather than ``table`` so it does not shadow this
    # endpoint function.
    checker = Table(html, img)
    # NOTE(review): check_html() presumably reports that a second pass is
    # needed on the image held by the Table object -- confirm in
    # cores.check_table.
    if checker.check_html():
        _, html = table_res(checker.img)
    if html:
        post_hander = PostHandler(html)
        # print(post_hander.format_predict_html)
        return {'html': post_hander.format_predict_html}
    else:
        raise Exception('无法识别')


print('table system init success!')