# -*- coding: UTF-8 -*-
import json
from base64 import b64decode
import base64
import cv2
import numpy as np
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from paddleocr import PaddleOCR, PPStructure
from sx_utils.sxweb import *
from sx_utils.sximage import *
import threading
import os
import re
from sx_utils.sx_log import *
import paddleclas
from cores.post_hander import *
from cores.check_table import *
# Called before any other module-level statement; presumably configures
# print/log formatting via sx_utils.sx_log — TODO confirm what it hooks.
format_print()
# Initialize the app
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests (Starlette echoes the request
# Origin instead of "*" in that case). Confirm this is intended before exposing
# the service publicly.
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Serializes calls to the shared table engine across concurrent requests
# (acquired in table_res).
table_engine_lock = threading.Lock()
# Table-only PPStructure pipeline: layout analysis disabled, GPU enabled, angle
# classification enabled, custom SLANet table-structure model. The commented-out
# det/rec model dirs are alternatives kept for reference.
table_engine = PPStructure(layout=False,
table=True,
use_gpu=True,
show_log=True,
use_angle_cls=True,
# det_model_dir="models/det/det_table_v2",
# det_model_dir="models/det/det_table_v3",
# rec_model_dir="models/rec/rec_table_v1",
table_model_dir="models/table/SLANet_911")
# Serializes calls to the shared orientation classifier (acquired in predict_cls).
cls_lock = threading.Lock()
# Text-image orientation classifier used by predict_cls to upright input images.
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
# Disabled alternative engine configurations, kept for reference. NOTE(review):
# they reference a `use_gpu` name that is not defined in this module as shown.
# # Regular tables
# table_engine = PPStructure(layout=False,
# table=True,
# use_gpu=use_gpu,
# show_log=True,
# det_model_dir="models/det/det_table_v2",
# rec_model_dir="./models/rec/rec_table_v1",
# table_model_dir="models/table/SLANet_v2")
#
# # Long tables
# table_engine1 = PPStructure(layout=False,
# table=True,
# use_gpu=use_gpu,
# show_log=True,
# det_model_dir="models/det/det_table_v1",
# rec_model_dir="./models/rec/rec_table_v1",
# table_model_dir="./models/table/SLAnet_v1")
#
# # Supplementary model for certain special cases
# table_engine2 = PPStructure(layout=False,
# table=True,
# use_gpu=use_gpu,
# show_log=True,
# det_model_dir="models/det/det_table_v3",
# rec_model_dir="./models/rec/rec_table_v1",
# table_model_dir="./models/table/SLAnet_v1")
#
#
#
def cal_html_to_chs(html):
    """Count the recognized text characters in a table-HTML string.

    Intended for comparing table recognition results across rotations: the
    more characters that were recognized, the better the result is assumed
    to be.

    Every markup tag (``<...>``, including tags with attributes such as
    ``<td colspan="2">``) and every whitespace character is removed. The
    remaining text is printed for debugging, and its length is returned.

    Args:
        html: HTML produced by the table engine. ``None`` or ``""`` count as 0.

    Returns:
        Number of non-whitespace characters outside of tags.
    """
    if not html:
        return 0
    # Generic tag stripping instead of a hand-maintained list of tag literals,
    # so new tags or attributes do not leak markup into the count.
    text = re.sub(r'<[^>]*>', '', html)
    text = ''.join(text.split())
    print(text)
    return len(text)
def predict_cls(image, conf=0.8):
    """Classify the text orientation of ``image`` with the shared PaddleClas model.

    Args:
        image: Image handed straight to ``cls_model.predict`` (presumably an
            OpenCV/numpy array — TODO confirm accepted types).
        conf: Minimum top-1 score required to accept the prediction.

    Returns:
        The top-1 label name converted to ``int`` (the caller expects
        0/90/180/270), or -1 when no top-1 score exceeds ``conf``.
    """
    # PaddleClas.predict is a generator: the model only runs while it is being
    # iterated. Materialize it inside the lock so inference itself is
    # serialized; ``with`` also guarantees we never release an unacquired lock.
    with cls_lock:
        batches = list(cls_model.predict(image))
    for batch in batches:
        top = batch[0]  # first (only) image of the batch
        score = top['scores'][0]
        label_name = top['label_names'][0]
        print(f"score: {score}, label_name: {label_name}")
        if score > conf:
            return int(label_name)
    return -1
def rotate_to_zero(image, current_degree):
    """Rotate ``image`` back to 0 degrees given its detected orientation.

    ``current_degree`` is floored to whole quarter turns (``// 90``).

    - Zero quarter turns: ``image`` itself is returned (same object).
    - Otherwise: ``cv2.rotate`` is called with code ``3 - quarter_turns``.
      That gives 1 -> code 2 (90 deg counterclockwise), 2 -> code 1 (180 deg),
      3 -> code 0 (90 deg clockwise).

    NOTE(review): assumes current_degree is one of 0/90/180/270. Values >= 360
    or negative values yield an invalid rotate code.
    """
    quarter_turns = current_degree // 90
    if quarter_turns != 0:
        return cv2.rotate(image, 3 - quarter_turns)
    return image
def get_zero_degree_image(img):
    """Return ``img`` turned upright according to the orientation classifier.

    Candidates are tried in order, each with a lower acceptance threshold:

    - the image as-is (threshold 0.8);
    - ``cv2.rotate`` code 0, 90 deg clockwise (threshold 0.6);
    - ``cv2.rotate`` code 1, 180 deg (threshold 0.4);
    - ``cv2.rotate`` code 2, 90 deg counterclockwise (threshold 0.2).

    The first candidate whose predicted orientation (0/90/180/270, or -1 when
    unrecognized) passes its threshold is rotated back to 0 degrees and
    returned. If none passes, ``img`` itself is returned.
    """
    rotate_codes = (-1, 0, 1, 2)  # -1 means "use the image unrotated"
    for attempt, code in enumerate(rotate_codes):
        candidate = img.copy() if code < 0 else cv2.rotate(img.copy(), code)
        threshold = 0.8 - (attempt * 0.2)
        degree = predict_cls(candidate, threshold)
        if degree == -1:
            continue
        return rotate_to_zero(candidate, degree)
    return img
def table_res(im, ROTATE=-1):
    """Upright ``im`` and run the shared PPStructure table engine on it.

    Args:
        im: Input image array. It is copied first, so the caller's array is not
            passed into the orientation/table pipeline.
        ROTATE: Unused; kept only so existing callers that pass it keep working.

    Returns:
        ``(res, html)``: the raw engine output and ``res[0]['res']['html']``.

    Raises:
        IndexError / KeyError: if the engine output lacks that first entry or key.
    """
    im = im.copy()
    img = get_zero_degree_image(im)
    # The engine is shared by all requests. ``with`` releases the lock on error
    # and, unlike acquire() inside try/finally, never releases an unheld lock.
    with table_engine_lock:
        res = table_engine(img)
    html = res[0]['res']['html']
    return res, html
class TableInfo(BaseModel):
    """Request body for ``POST /ocr_system/table``."""
    # Encoded image; decoded with base64_to_np by the endpoint (presumably a
    # base64 string — TODO confirm the helper's expected format).
    image: str
    # Required by the schema, but not read by any code visible in this module.
    det: str
@app.get("/ping")
def ping():
    """Liveness probe: always answers with a fixed string."""
    reply = 'pong!!!!!!!!!'
    return reply
@app.post("/ocr_system/table")
@web_try()
def table(image: TableInfo):
    """Recognize the table in the posted image and return it as HTML.

    Flow:

    1. Decode ``image.image`` with ``base64_to_np``.
    2. Run ``table_res`` on the decoded image.
    3. If ``Table.check_html()`` is truthy, rerun ``table_res`` on ``Table.img``.
    4. Return ``{'html': PostHandler(html).format_predict_html}``.

    Raises ``Exception('无法识别')`` ("cannot recognize") when the final HTML is
    empty. ``web_try`` presumably turns exceptions into an error response
    — TODO confirm.
    """
    decoded = base64_to_np(image.image)
    _, html = table_res(decoded)
    checker = Table(html, decoded)
    if checker.check_html():
        _, html = table_res(checker.img)
    if not html:
        raise Exception('无法识别')
    formatter = PostHandler(html)
    return {'html': formatter.format_predict_html}
# Printed once at import time, after all of the module-level setup above has run.
print('table system init success!')