# -*- coding: UTF-8 -*-
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from cores.table_engine import CustomPPStructure
from sx_utils.sxweb import *
from sx_utils.sximage import *
import threading
from sx_utils.sx_log import *
import paddleclas
from cores.post_hander import *
from cores.check_table import *
format_print()
# 初始化APP
app = FastAPI()
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
table_engine_lock = threading.Lock()
# 表格识别模型
table_engine = CustomPPStructure(layout=False,
table=True,
use_gpu=True,
show_log=True,
use_angle_cls=True,
table_model_dir="models/table/SLANet_911")
cls_lock = threading.Lock()
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
# 用于判断各个角度table的识别效果,识别的字段越多,效果越好
def cal_html_to_chs(html):
"""
将HTML中的表格数据提取并合并为中文字符串。
Parameters:
html (str): 输入的HTML字符串。
Returns:
int: 合并后的中文字符串长度。
"""
res = []
rows = re.split('
', html)
for row in rows:
row = re.split('', row)
cells = list(map(lambda x: x.replace(' | ', '').replace('
', ''), row))
rec_str = ''.join(cells)
for tag in ['', '', '', '', '', '', '']:
rec_str = rec_str.replace(tag, '')
res.append(rec_str)
rec_res = ''.join(res).replace(' ', '')
rec_res = re.split('', rec_res)
rec_res = ''.join(rec_res).replace(' ', '')
return len(rec_res)
def predict_cls(image, conf=0.8):
"""
使用分类模型对图像进行预测,并返回预测结果。
Parameters:
image (np.ndarray): 输入的图像数组。
conf (float): 置信度阈值,默认为0.8。
Returns:
int: 预测结果的类别标签。
"""
try:
cls_lock.acquire()
result = cls_model.predict(image)
finally:
cls_lock.release()
for res in result:
score = res[0]['scores'][0]
label_name = res[0]['label_names'][0]
print(f"score: {score}, label_name: {label_name}")
if score > conf:
return int(label_name)
return -1
def rotate_to_zero(image, current_degree):
"""
将图像旋转至零度方向。
Parameters:
image (np.ndarray): 输入的图像数组。
current_degree (float): 当前的旋转角度。
Returns:
np.ndarray: 旋转后的图像数组。
"""
current_degree = current_degree // 90
if current_degree == 0:
return image
to_rotate = (4 - current_degree) - 1
image = cv2.rotate(image, to_rotate)
return image
def get_zero_degree_image(img):
"""
获取经零度方向旋转后的图像。
Parameters:
img (np.ndarray): 输入的图像数组。
Returns:
np.ndarray: 经零度方向旋转后的图像数组。
"""
step = 0.2
for idx, i in enumerate([-1, 0, 1, 2]):
if i >= 0:
image = cv2.rotate(img.copy(), i)
else:
image = img.copy()
conf = 0.8 - (idx * step)
current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来
if current_degree != -1:
img = rotate_to_zero(image, current_degree)
break
else:
continue
return img
def table_res(im, ROTATE=-1, prefer_cell=False):
"""
获取图像中表格的识别结果和HTML字符串。
Parameters:
im (np.ndarray): 输入的图像数组。
ROTATE (int): 旋转角度,默认为-1。
prefer_cell (bool): 是否使用cell_boxes替代dt_boxes,默认为否。
Returns:
Tuple: 表格识别结果和HTML字符串。
"""
im = im.copy()
img = get_zero_degree_image(im)
try:
table_engine_lock.acquire()
res = table_engine(img, prefer_table_cell_boxes=prefer_cell)
finally:
table_engine_lock.release()
html = res[0]['res']['html']
return res, html
class TableInfo(BaseModel):
image: str
det: str
prefer_cell: bool = Field(
default=False,
description="是否使用cell_boxes替代dt_boxes"
)
@app.get("/ping")
def ping():
"""
用于检查服务是否存活的端点。
Returns:
str: 返回pong表示服务存活。
"""
return 'pong!!!!!!!!!'
@app.post("/ocr_system/table")
@web_try()
def table(info: TableInfo):
"""
对图像中的表格进行识别并返回HTML字符串。
Parameters:
info (TableInfo): 输入的图像信息。
Returns:
dict: 包含HTML字符串的字典。
"""
# 转换图片格式
img = base64_to_np(info.image)
# 进行表格识别
res, html = table_res(img, prefer_cell=info.prefer_cell)
# 创建Table实例
table = Table(html, img)
# 效果不好则重新识别
if table.check_html():
res, html = table_res(table.img, prefer_cell=info.prefer_cell)
if html:
post_handler = PostHandler(html)
return {'html': post_handler.format_predict_html}
else:
raise Exception('无法识别')
print('table system init success!')