|
@@ -13,6 +13,8 @@ import threading
|
|
|
import os
|
|
|
import re
|
|
|
from sx_utils.sx_log import *
|
|
|
+import paddleclas
|
|
|
+
|
|
|
|
|
|
|
|
|
|
|
@@ -41,6 +43,10 @@ table_engine = PPStructure(layout=False,
|
|
|
# rec_model_dir="models/rec/rec_table_v1",
|
|
|
table_model_dir="models/table/SLANet_905")
|
|
|
|
|
|
+
|
|
|
# Lock that predict_cls takes around its call to cls_model.predict.
+cls_lock = threading.Lock()
|
|
|
+
|
|
|
# Single module-level PaddleClas instance, constructed at import time with
# the "text_image_orientation" model and reused by predict_cls.
+cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
|
|
|
# # 普通表格
|
|
|
# table_engine = PPStructure(layout=False,
|
|
|
# table=True,
|
|
@@ -91,13 +97,54 @@ def cal_html_to_chs(html):
|
|
|
return len(rec_res)
|
|
|
|
|
|
|
|
|
+
|
|
|
def predict_cls(image, conf=0.8):
    """Classify the orientation of ``image`` with the shared classifier.

    Args:
        image: Input forwarded unchanged to ``cls_model.predict``.
        conf: Score threshold. A prediction is accepted only when its top
            score is strictly greater than this value. Defaults to 0.8.

    Returns:
        The top label converted with ``int`` for the first result whose top
        score exceeds ``conf`` (callers in this file treat it as degrees:
        0, 90, 180 or 270), or -1 when no result qualifies.
    """
    # PaddleClas.predict is documented as a generator, so the model only runs
    # while the generator is being consumed. Drain it while holding the lock;
    # otherwise the lock would merely cover creating the generator and the
    # actual inference would run unguarded.
    with cls_lock:
        results = list(cls_model.predict(image))

    for res in results:
        if not res:
            # Nothing was predicted for this batch; look at the next one.
            continue
        top = res[0]
        score = top['scores'][0]
        label_name = top['label_names'][0]
        print(f'score: {score}, label_name: {label_name}')
        if score > conf:
            return int(label_name)
    return -1
|
|
|
+
|
|
|
+
|
|
|
def rotate_to_zero(image, current_degree):
    """Rotate ``image`` back to the 0 degree orientation.

    Args:
        image: Image passed to ``cv2.rotate`` when a rotation is needed.
        current_degree: Detected orientation in degrees. Only whole
            90-degree steps are considered, and the value is reduced modulo
            a full turn (so 360 behaves like 0). A negative value, such as
            the -1 "not recognised" sentinel used in this file, leaves the
            image untouched.

    Returns:
        ``image`` itself when no rotation is required, otherwise the result
        of ``cv2.rotate``.
    """
    if current_degree < 0:
        # Unknown orientation: there is nothing sensible to undo.
        return image

    quarter_turns = (current_degree // 90) % 4
    if quarter_turns == 0:
        return image

    # Same mapping the previous arithmetic (3 - quarter_turns) produced,
    # spelled with OpenCV's named flags instead of their numeric values.
    rotate_flags = {
        1: cv2.ROTATE_90_COUNTERCLOCKWISE,
        2: cv2.ROTATE_180,
        3: cv2.ROTATE_90_CLOCKWISE,
    }
    return cv2.rotate(image, rotate_flags[quarter_turns])
|
|
|
+
|
|
|
+
|
|
|
def get_zero_degree_image(img):
    """Return an upright version of ``img`` when its orientation is detected.

    The image is classified up to four times: first as given, then after
    ``cv2.rotate`` with codes 0, 1 and 2. Each attempt lowers the confidence
    threshold passed to ``predict_cls`` by 0.2, starting from 0.8. The first
    attempt that yields a degree other than -1 wins, and that candidate is
    rotated back with ``rotate_to_zero``.

    Args:
        img: Source image; it is copied before every attempt.

    Returns:
        The corrected candidate from the first successful attempt, or ``img``
        unchanged when every attempt returns -1.
    """
    threshold_step = 0.2
    rotation_codes = [-1, 0, 1, 2]
    for attempt, code in enumerate(rotation_codes):
        # -1 means "classify the image as given"; other values are cv2 codes.
        candidate = img.copy() if code < 0 else cv2.rotate(img.copy(), code)
        threshold = 0.8 - (attempt * threshold_step)
        # predict_cls yields 0 / 90 / 180 / 270, or -1 when nothing is recognised.
        detected_degree = predict_cls(candidate, threshold)
        if detected_degree == -1:
            continue
        return rotate_to_zero(candidate, detected_degree)
    return img
|
|
|
+
|
|
|
def table_res(im, ROTATE=-1):
|
|
|
im = im.copy()
|
|
|
- if ROTATE >= 0:
|
|
|
- im = cv2.rotate(im, ROTATE)
|
|
|
+ # 获取正向图片
|
|
|
+ img = get_zero_degree_image(im)
|
|
|
try:
|
|
|
table_engine_lock.acquire()
|
|
|
- res = table_engine(im)
|
|
|
+ res = table_engine(img)
|
|
|
finally:
|
|
|
table_engine_lock.release()
|
|
|
html = res[0]['res']['html']
|
|
@@ -117,16 +164,7 @@ def ping():
|
|
|
@web_try()
|
|
|
def table(image: TableInfo):
|
|
|
img = base64_to_np(image.image)
|
|
|
- res_len = 0
|
|
|
- res = None
|
|
|
- for i in [-1, 0, 1, 2]:
|
|
|
- _res, html = table_res(img, i)
|
|
|
- print(html)
|
|
|
- _res_len = cal_html_to_chs(html)
|
|
|
- if _res_len > res_len:
|
|
|
- res = _res
|
|
|
- res_len = _res_len
|
|
|
-
|
|
|
+ res, html = table_res(img)
|
|
|
if res:
|
|
|
return res[0]['res']
|
|
|
else:
|