123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- # -*- coding: UTF-8 -*-
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel, Field
- from cores.table_engine import CustomPPStructure
- from sx_utils.sxweb import *
- from sx_utils.sximage import *
- import threading
- from sx_utils.sx_log import *
- import paddleclas
- from cores.post_hander import *
- from cores.check_table import *
- format_print()
- # 初始化APP
- app = FastAPI()
- origins = ["*"]
- app.add_middleware(
- CORSMiddleware,
- allow_origins=origins,
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
- table_engine_lock = threading.Lock()
- # 表格识别模型
- table_engine = CustomPPStructure(layout=False,
- table=True,
- use_gpu=True,
- show_log=True,
- use_angle_cls=True,
- table_model_dir="models/table/SLANet_911")
- cls_lock = threading.Lock()
- cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
- # 用于判断各个角度table的识别效果,识别的字段越多,效果越好
- def cal_html_to_chs(html):
- """
- 将HTML中的表格数据提取并合并为中文字符串。
- Parameters:
- html (str): 输入的HTML字符串。
- Returns:
- int: 合并后的中文字符串长度。
- """
- res = []
- rows = re.split('<tr>', html)
- for row in rows:
- row = re.split('<td>', row)
- cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
- rec_str = ''.join(cells)
- for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
- rec_str = rec_str.replace(tag, '')
- res.append(rec_str)
- rec_res = ''.join(res).replace(' ', '')
- rec_res = re.split('<tdcolspan="\w+">', rec_res)
- rec_res = ''.join(rec_res).replace(' ', '')
- return len(rec_res)
- def predict_cls(image, conf=0.8):
- """
- 使用分类模型对图像进行预测,并返回预测结果。
- Parameters:
- image (np.ndarray): 输入的图像数组。
- conf (float): 置信度阈值,默认为0.8。
- Returns:
- int: 预测结果的类别标签。
- """
- try:
- cls_lock.acquire()
- result = cls_model.predict(image)
- finally:
- cls_lock.release()
- for res in result:
- score = res[0]['scores'][0]
- label_name = res[0]['label_names'][0]
- print(f"score: {score}, label_name: {label_name}")
- if score > conf:
- return int(label_name)
- return -1
- def rotate_to_zero(image, current_degree):
- """
- 将图像旋转至零度方向。
- Parameters:
- image (np.ndarray): 输入的图像数组。
- current_degree (float): 当前的旋转角度。
- Returns:
- np.ndarray: 旋转后的图像数组。
- """
- current_degree = current_degree // 90
- if current_degree == 0:
- return image
- to_rotate = (4 - current_degree) - 1
- image = cv2.rotate(image, to_rotate)
- return image
- def get_zero_degree_image(img):
- """
- 获取经零度方向旋转后的图像。
- Parameters:
- img (np.ndarray): 输入的图像数组。
- Returns:
- np.ndarray: 经零度方向旋转后的图像数组。
- """
- step = 0.2
- for idx, i in enumerate([-1, 0, 1, 2]):
- if i >= 0:
- image = cv2.rotate(img.copy(), i)
- else:
- image = img.copy()
- conf = 0.8 - (idx * step)
- current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来
- if current_degree != -1:
- img = rotate_to_zero(image, current_degree)
- break
- else:
- continue
- return img
- def table_res(im, ROTATE=-1, prefer_cell=False):
- """
- 获取图像中表格的识别结果和HTML字符串。
- Parameters:
- im (np.ndarray): 输入的图像数组。
- ROTATE (int): 旋转角度,默认为-1。
- prefer_cell (bool): 是否使用cell_boxes替代dt_boxes,默认为否。
- Returns:
- Tuple: 表格识别结果和HTML字符串。
- """
- im = im.copy()
- img = get_zero_degree_image(im)
- try:
- table_engine_lock.acquire()
- res = table_engine(img, prefer_table_cell_boxes=prefer_cell)
- finally:
- table_engine_lock.release()
- html = res[0]['res']['html']
- return res, html
- class TableInfo(BaseModel):
- image: str
- det: str
- prefer_cell: bool = Field(
- default=False,
- description="是否使用cell_boxes替代dt_boxes"
- )
- @app.get("/ping")
- def ping():
- """
- 用于检查服务是否存活的端点。
- Returns:
- str: 返回pong表示服务存活。
- """
- return 'pong!!!!!!!!!'
- @app.post("/ocr_system/table")
- @web_try()
- def table(info: TableInfo):
- """
- 对图像中的表格进行识别并返回HTML字符串。
- Parameters:
- info (TableInfo): 输入的图像信息。
- Returns:
- dict: 包含HTML字符串的字典。
- """
- # 转换图片格式
- img = base64_to_np(info.image)
- # 进行表格识别
- res, html = table_res(img, prefer_cell=info.prefer_cell)
- # 创建Table实例
- table = Table(html, img)
- # 效果不好则重新识别
- if table.check_html():
- res, html = table_res(table.img, prefer_cell=info.prefer_cell)
- if html:
- post_handler = PostHandler(html)
- return {'html': post_handler.format_predict_html}
- else:
- raise Exception('无法识别')
- print('table system init success!')
|