|
@@ -1,19 +1,11 @@
|
|
# -*- coding: UTF-8 -*-
|
|
# -*- coding: UTF-8 -*-
|
|
-import json
|
|
|
|
-from base64 import b64decode
|
|
|
|
-import base64
|
|
|
|
-
|
|
|
|
-import cv2
|
|
|
|
-import numpy as np
|
|
|
|
-from fastapi import FastAPI, Request
|
|
|
|
|
|
+from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
-from paddleocr import PaddleOCR, PPStructure
|
|
|
|
|
|
+from paddleocr import PPStructure
|
|
from sx_utils.sxweb import *
|
|
from sx_utils.sxweb import *
|
|
from sx_utils.sximage import *
|
|
from sx_utils.sximage import *
|
|
import threading
|
|
import threading
|
|
-import os
|
|
|
|
-import re
|
|
|
|
from sx_utils.sx_log import *
|
|
from sx_utils.sx_log import *
|
|
import paddleclas
|
|
import paddleclas
|
|
|
|
|
|
@@ -35,15 +27,12 @@ app.add_middleware(
|
|
)
|
|
)
|
|
|
|
|
|
# Lock acquired around calls to the shared table engine (see table_res).
table_engine_lock = threading.Lock()

# Table recognition model: layout analysis disabled, table recognition
# enabled, using the SLANet_911 table model directory.
table_engine = PPStructure(
    layout=False,
    table=True,
    use_gpu=True,
    show_log=True,
    use_angle_cls=True,
    table_model_dir="models/table/SLANet_911",
)
|
|
|
|
|
|
cls_lock = threading.Lock()
|
|
cls_lock = threading.Lock()
|
|
@@ -51,38 +40,17 @@ cls_lock = threading.Lock()
|
|
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
|
|
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
|
|
|
|
|
|
|
|
|
|
-# # 普通表格
|
|
|
|
-# table_engine = PPStructure(layout=False,
|
|
|
|
-# table=True,
|
|
|
|
-# use_gpu=use_gpu,
|
|
|
|
-# show_log=True,
|
|
|
|
-# det_model_dir="models/det/det_table_v2",
|
|
|
|
-# rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
-# table_model_dir="models/table/SLANet_v2")
|
|
|
|
-#
|
|
|
|
-# # 长度较长表格
|
|
|
|
-# table_engine1 = PPStructure(layout=False,
|
|
|
|
-# table=True,
|
|
|
|
-# use_gpu=use_gpu,
|
|
|
|
-# show_log=True,
|
|
|
|
-# det_model_dir="models/det/det_table_v1",
|
|
|
|
-# rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
-# table_model_dir="./models/table/SLAnet_v1")
|
|
|
|
-#
|
|
|
|
-# # 针对某些特殊情况的补充模型
|
|
|
|
-# table_engine2 = PPStructure(layout=False,
|
|
|
|
-# table=True,
|
|
|
|
-# use_gpu=use_gpu,
|
|
|
|
-# show_log=True,
|
|
|
|
-# det_model_dir="models/det/det_table_v3",
|
|
|
|
-# rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
-# table_model_dir="./models/table/SLAnet_v1")
|
|
|
|
-#
|
|
|
|
-#
|
|
|
|
-#
|
|
|
|
-
|
|
|
|
# 用于判断各个角度table的识别效果,识别的字段越多,效果越好
|
|
# 用于判断各个角度table的识别效果,识别的字段越多,效果越好
|
|
def cal_html_to_chs(html):
|
|
def cal_html_to_chs(html):
|
|
|
|
+ """
|
|
|
|
+ 将HTML中的表格数据提取并合并为中文字符串。
|
|
|
|
+
|
|
|
|
+ Parameters:
|
|
|
|
+ html (str): 输入的HTML字符串。
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ int: 合并后的中文字符串长度。
|
|
|
|
+ """
|
|
res = []
|
|
res = []
|
|
rows = re.split('<tr>', html)
|
|
rows = re.split('<tr>', html)
|
|
for row in rows:
|
|
for row in rows:
|
|
@@ -97,11 +65,20 @@ def cal_html_to_chs(html):
|
|
rec_res = ''.join(res).replace(' ', '')
|
|
rec_res = ''.join(res).replace(' ', '')
|
|
rec_res = re.split('<tdcolspan="\w+">', rec_res)
|
|
rec_res = re.split('<tdcolspan="\w+">', rec_res)
|
|
rec_res = ''.join(rec_res).replace(' ', '')
|
|
rec_res = ''.join(rec_res).replace(' ', '')
|
|
- print(rec_res)
|
|
|
|
return len(rec_res)
|
|
return len(rec_res)
|
|
|
|
|
|
|
|
|
|
def predict_cls(image, conf=0.8):
|
|
def predict_cls(image, conf=0.8):
|
|
|
|
+ """
|
|
|
|
+ 使用分类模型对图像进行预测,并返回预测结果。
|
|
|
|
+
|
|
|
|
+ Parameters:
|
|
|
|
+ image (np.ndarray): 输入的图像数组。
|
|
|
|
+ conf (float): 置信度阈值,默认为0.8。
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ int: 预测结果的类别标签。
|
|
|
|
+ """
|
|
try:
|
|
try:
|
|
cls_lock.acquire()
|
|
cls_lock.acquire()
|
|
result = cls_model.predict(image)
|
|
result = cls_model.predict(image)
|
|
@@ -111,24 +88,40 @@ def predict_cls(image, conf=0.8):
|
|
score = res[0]['scores'][0]
|
|
score = res[0]['scores'][0]
|
|
label_name = res[0]['label_names'][0]
|
|
label_name = res[0]['label_names'][0]
|
|
print(f"score: {score}, label_name: {label_name}")
|
|
print(f"score: {score}, label_name: {label_name}")
|
|
- # print(conf)
|
|
|
|
if score > conf:
|
|
if score > conf:
|
|
return int(label_name)
|
|
return int(label_name)
|
|
return -1
|
|
return -1
|
|
|
|
|
|
|
|
|
|
def rotate_to_zero(image, current_degree):
    """
    Rotate an image back to the upright (0 degree) orientation.

    Parameters:
        image: Image array that is handed to ``cv2.rotate``.
        current_degree: Current orientation of the image in degrees. The
            value is floor-divided by 90, so only the quarter turn it falls
            into matters; whole turns (e.g. 360 or -90) are wrapped into
            the 0..3 range.

    Returns:
        The input image object itself when no rotation is needed,
        otherwise the rotated image returned by ``cv2.rotate``.
    """
    # Wrap the quarter-turn count so 360 behaves like 0 and -90 like 270,
    # instead of producing a rotate code outside the values OpenCV accepts.
    quarter_turns = int(current_degree // 90) % 4
    if quarter_turns == 0:
        return image

    # Equivalent to the numeric code ``3 - quarter_turns``:
    # 1 -> 2 (90 CCW), 2 -> 1 (180), 3 -> 0 (90 CW).
    rotate_code = {
        1: cv2.ROTATE_90_COUNTERCLOCKWISE,
        2: cv2.ROTATE_180,
        3: cv2.ROTATE_90_CLOCKWISE,
    }[quarter_turns]
    return cv2.rotate(image, rotate_code)
|
|
|
|
|
|
|
|
|
|
def get_zero_degree_image(img):
|
|
def get_zero_degree_image(img):
|
|
|
|
+ """
|
|
|
|
+ 获取经零度方向旋转后的图像。
|
|
|
|
+
|
|
|
|
+ Parameters:
|
|
|
|
+ img (np.ndarray): 输入的图像数组。
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ np.ndarray: 经零度方向旋转后的图像数组。
|
|
|
|
+ """
|
|
step = 0.2
|
|
step = 0.2
|
|
for idx, i in enumerate([-1, 0, 1, 2]):
|
|
for idx, i in enumerate([-1, 0, 1, 2]):
|
|
if i >= 0:
|
|
if i >= 0:
|
|
@@ -146,11 +139,18 @@ def get_zero_degree_image(img):
|
|
|
|
|
|
|
|
|
|
def table_res(im, ROTATE=-1):
|
|
def table_res(im, ROTATE=-1):
|
|
|
|
+ """
|
|
|
|
+ 获取图像中表格的识别结果和HTML字符串。
|
|
|
|
+
|
|
|
|
+ Parameters:
|
|
|
|
+ im (np.ndarray): 输入的图像数组。
|
|
|
|
+ ROTATE (int): 旋转角度,默认为-1。
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ Tuple: 表格识别结果和HTML字符串。
|
|
|
|
+ """
|
|
im = im.copy()
|
|
im = im.copy()
|
|
- # cv2.imwrite('before-rotate.jpg', im)
|
|
|
|
- # 获取正向图片
|
|
|
|
img = get_zero_degree_image(im)
|
|
img = get_zero_degree_image(im)
|
|
- # cv2.imwrite('after-rotate.jpg', img)
|
|
|
|
try:
|
|
try:
|
|
table_engine_lock.acquire()
|
|
table_engine_lock.acquire()
|
|
res = table_engine(img)
|
|
res = table_engine(img)
|
|
@@ -167,23 +167,40 @@ class TableInfo(BaseModel):
|
|
|
|
|
|
@app.get("/ping")
def ping():
    """
    Liveness probe endpoint.

    Returns:
        str: The fixed reply ``'pong!!!!!!!!!'``.
    """
    reply = 'pong!!!!!!!!!'
    return reply
|
|
|
|
|
|
|
|
|
|
@app.post("/ocr_system/table")
@web_try()
def table(image: TableInfo):
    """
    Recognize the table in the submitted image and return it as HTML.

    Parameters:
        image (TableInfo): Request body; its ``image`` field is decoded
            with ``base64_to_np``.

    Returns:
        dict: ``{'html': ...}`` holding ``PostHandler(html).format_predict_html``.

    Raises:
        Exception: When recognition yields no HTML.
    """
    # Decode the payload into an image array.
    decoded = base64_to_np(image.image)

    # First recognition pass; only the HTML part of the result is used.
    _, html = table_res(decoded)

    # Wrap the result so it can be checked.
    candidate = Table(html, decoded)

    # When check_html() is truthy, run recognition again on the image held
    # by the Table instance (presumably a corrected one - confirm in Table).
    if candidate.check_html():
        _, html = table_res(candidate.img)

    if not html:
        raise Exception('无法识别')

    return {'html': PostHandler(html).format_predict_html}
|
|
|
|
|