# -*- coding: UTF-8 -*- from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from paddleocr import PPStructure from sx_utils.sxweb import * from sx_utils.sximage import * import threading from sx_utils.sx_log import * import paddleclas from cores.post_hander import * from cores.check_table import * format_print() # 初始化APP app = FastAPI() origins = ["*"] app.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) table_engine_lock = threading.Lock() # 表格识别模型 table_engine = PPStructure(layout=False, table=True, use_gpu=True, show_log=True, use_angle_cls=True, det_model_dir="models/det/det_pse_quad", table_model_dir="models/table/SLANet_911") cls_lock = threading.Lock() cls_model = paddleclas.PaddleClas(model_name="text_image_orientation") # 用于判断各个角度table的识别效果,识别的字段越多,效果越好 def cal_html_to_chs(html): """ 将HTML中的表格数据提取并合并为中文字符串。 Parameters: html (str): 输入的HTML字符串。 Returns: int: 合并后的中文字符串长度。 """ res = [] rows = re.split('', html) for row in rows: row = re.split('', row) cells = list(map(lambda x: x.replace('', '').replace('', ''), row)) rec_str = ''.join(cells) for tag in ['', '', '', '', '', '
', '', '']: rec_str = rec_str.replace(tag, '') res.append(rec_str) rec_res = ''.join(res).replace(' ', '') rec_res = re.split('', rec_res) rec_res = ''.join(rec_res).replace(' ', '') return len(rec_res) def predict_cls(image, conf=0.8): """ 使用分类模型对图像进行预测,并返回预测结果。 Parameters: image (np.ndarray): 输入的图像数组。 conf (float): 置信度阈值,默认为0.8。 Returns: int: 预测结果的类别标签。 """ try: cls_lock.acquire() result = cls_model.predict(image) finally: cls_lock.release() for res in result: score = res[0]['scores'][0] label_name = res[0]['label_names'][0] print(f"score: {score}, label_name: {label_name}") if score > conf: return int(label_name) return -1 def rotate_to_zero(image, current_degree): """ 将图像旋转至零度方向。 Parameters: image (np.ndarray): 输入的图像数组。 current_degree (float): 当前的旋转角度。 Returns: np.ndarray: 旋转后的图像数组。 """ current_degree = current_degree // 90 if current_degree == 0: return image to_rotate = (4 - current_degree) - 1 image = cv2.rotate(image, to_rotate) return image def get_zero_degree_image(img): """ 获取经零度方向旋转后的图像。 Parameters: img (np.ndarray): 输入的图像数组。 Returns: np.ndarray: 经零度方向旋转后的图像数组。 """ step = 0.2 for idx, i in enumerate([-1, 0, 1, 2]): if i >= 0: image = cv2.rotate(img.copy(), i) else: image = img.copy() conf = 0.8 - (idx * step) current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来 if current_degree != -1: img = rotate_to_zero(image, current_degree) break else: continue return img def table_res(im, ROTATE=-1): """ 获取图像中表格的识别结果和HTML字符串。 Parameters: im (np.ndarray): 输入的图像数组。 ROTATE (int): 旋转角度,默认为-1。 Returns: Tuple: 表格识别结果和HTML字符串。 """ im = im.copy() img = get_zero_degree_image(im) try: table_engine_lock.acquire() res = table_engine(img) finally: table_engine_lock.release() html = res[0]['res']['html'] return res, html class TableInfo(BaseModel): image: str det: str @app.get("/ping") def ping(): """ 用于检查服务是否存活的端点。 Returns: str: 返回pong表示服务存活。 """ return 'pong!!!!!!!!!' @app.post("/ocr_system/table") @web_try() def table(image: TableInfo): """ 对图像中的表格进行识别并返回HTML字符串。 Parameters: image (TableInfo): 输入的图像信息。 Returns: dict: 包含HTML字符串的字典。 """ # 转换图片格式 img = base64_to_np(image.image) # 进行表格识别 res, html = table_res(img) # 创建Table实例 table = Table(html, img) # 效果不好则重新识别 if table.check_html(): res, html = table_res(table.img) if html: post_handler = PostHandler(html) return {'html': post_handler.format_predict_html} else: raise Exception('无法识别') print('table system init success!')