|
@@ -9,8 +9,13 @@ from pydantic import BaseModel
|
|
from paddleocr import PaddleOCR, PPStructure
|
|
from paddleocr import PaddleOCR, PPStructure
|
|
from sx_utils.sxweb import *
|
|
from sx_utils.sxweb import *
|
|
from sx_utils.sximage import *
|
|
from sx_utils.sximage import *
|
|
-
|
|
|
|
|
|
+import threading
|
|
import os
|
|
import os
|
|
|
|
+import re
|
|
|
|
+
|
|
|
|
+table_engine_lock = threading.Lock()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
|
|
# 初始化app
|
|
# 初始化app
|
|
app = FastAPI()
|
|
app = FastAPI()
|
|
@@ -24,36 +29,78 @@ app.add_middleware(
|
|
allow_headers=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
)
|
|
|
|
|
|
-use_gpu = os.getenv('USE_CUDA') == 'gpu'
|
|
|
|
-print(f'use gpu: {use_gpu}')
|
|
|
|
|
|
|
|
-# 普通表格
|
|
|
|
|
|
+
|
|
|
|
+
|
|
table_engine = PPStructure(layout=False,
|
|
table_engine = PPStructure(layout=False,
|
|
table=True,
|
|
table=True,
|
|
- use_gpu=use_gpu,
|
|
|
|
|
|
+ use_gpu=True,
|
|
show_log=True,
|
|
show_log=True,
|
|
- det_model_dir="models/det/det_table_v2",
|
|
|
|
- rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
- table_model_dir="models/table/SLAnet_v1")
|
|
|
|
-
|
|
|
|
-# 长度较长表格
|
|
|
|
-table_engine1 = PPStructure(layout=False,
|
|
|
|
- table=True,
|
|
|
|
- use_gpu=use_gpu,
|
|
|
|
- show_log=True,
|
|
|
|
- det_model_dir="models/det/det_table_v1",
|
|
|
|
- rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
- table_model_dir="./models/table/SLAnet_v1")
|
|
|
|
-
|
|
|
|
-# 针对某些特殊情况的补充模型
|
|
|
|
-table_engine2 = PPStructure(layout=False,
|
|
|
|
- table=True,
|
|
|
|
- use_gpu=use_gpu,
|
|
|
|
- show_log=True,
|
|
|
|
- det_model_dir="models/det/det_table_v3",
|
|
|
|
- rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
- table_model_dir="./models/table/SLAnet_v1")
|
|
|
|
-
|
|
|
|
|
|
+ # det_model_dir="models/det/det_table_v2",
|
|
|
|
+ # rec_model_dir="models/rec/rec_table_v1",
|
|
|
|
+ table_model_dir="models/table/SLANet_829")
|
|
|
|
+
|
|
|
|
+# # 普通表格
|
|
|
|
+# table_engine = PPStructure(layout=False,
|
|
|
|
+# table=True,
|
|
|
|
+# use_gpu=use_gpu,
|
|
|
|
+# show_log=True,
|
|
|
|
+# det_model_dir="models/det/det_table_v2",
|
|
|
|
+# rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
+# table_model_dir="models/table/SLANet_v2")
|
|
|
|
+#
|
|
|
|
+# # 长度较长表格
|
|
|
|
+# table_engine1 = PPStructure(layout=False,
|
|
|
|
+# table=True,
|
|
|
|
+# use_gpu=use_gpu,
|
|
|
|
+# show_log=True,
|
|
|
|
+# det_model_dir="models/det/det_table_v1",
|
|
|
|
+# rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
+# table_model_dir="./models/table/SLAnet_v1")
|
|
|
|
+#
|
|
|
|
+# # 针对某些特殊情况的补充模型
|
|
|
|
+# table_engine2 = PPStructure(layout=False,
|
|
|
|
+# table=True,
|
|
|
|
+# use_gpu=use_gpu,
|
|
|
|
+# show_log=True,
|
|
|
|
+# det_model_dir="models/det/det_table_v3",
|
|
|
|
+# rec_model_dir="./models/rec/rec_table_v1",
|
|
|
|
+# table_model_dir="./models/table/SLAnet_v1")
|
|
|
|
+#
|
|
|
|
+#
|
|
|
|
+#
|
|
|
|
+
|
|
|
|
+# 用于判断各个角度table的识别效果,识别的字段越多,效果越好
|
|
|
|
+def cal_html_to_chs(html):
|
|
|
|
+ res = []
|
|
|
|
+ rows = re.split('<tr>', html)
|
|
|
|
+ for row in rows:
|
|
|
|
+ row = re.split('<td>', row)
|
|
|
|
+ cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
|
|
|
|
+ rec_str = ''.join(cells)
|
|
|
|
+ for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
|
|
|
|
+ rec_str = rec_str.replace(tag, '')
|
|
|
|
+
|
|
|
|
+ res.append(rec_str)
|
|
|
|
+
|
|
|
|
+ rec_res = ''.join(res).replace(' ', '')
|
|
|
|
+ rec_res = re.split('<tdcolspan="\w+">', rec_res)
|
|
|
|
+ rec_res = ''.join(rec_res).replace(' ', '')
|
|
|
|
+ print(rec_res)
|
|
|
|
+ return len(rec_res)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def table_res(im, ROTATE=-1):
|
|
|
|
+ im = im.copy()
|
|
|
|
+ if ROTATE >= 0:
|
|
|
|
+ im = cv2.rotate(im, ROTATE)
|
|
|
|
+ try:
|
|
|
|
+ table_engine_lock.acquire()
|
|
|
|
+ res = table_engine(im)
|
|
|
|
+ finally:
|
|
|
|
+ table_engine_lock.release()
|
|
|
|
+ html = res[0]['res']['html']
|
|
|
|
+ return res, html
|
|
|
|
|
|
class TableInfo(BaseModel):
|
|
class TableInfo(BaseModel):
|
|
image: str
|
|
image: str
|
|
@@ -62,30 +109,27 @@ class TableInfo(BaseModel):
|
|
|
|
|
|
@app.get("/ping")
|
|
@app.get("/ping")
|
|
def ping():
|
|
def ping():
|
|
- return 'pong!'
|
|
|
|
|
|
+ return 'pong!!!!!!!!!'
|
|
|
|
|
|
|
|
|
|
@app.post("/ocr_system/table")
|
|
@app.post("/ocr_system/table")
|
|
@web_try()
|
|
@web_try()
|
|
def table(image: TableInfo):
|
|
def table(image: TableInfo):
|
|
img = base64_to_np(image.image)
|
|
img = base64_to_np(image.image)
|
|
- if image.det == 'no':
|
|
|
|
- res = table_engine(img)
|
|
|
|
- elif image.det == 'yes':
|
|
|
|
- res = table_engine1(img)
|
|
|
|
- elif image.det == 'spe':
|
|
|
|
- res = table_engine2(img)
|
|
|
|
- return res[0]['res']
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-if __name__ == '__main__':
|
|
|
|
- import uvicorn
|
|
|
|
- import argparse
|
|
|
|
-
|
|
|
|
- parser = argparse.ArgumentParser()
|
|
|
|
- parser.add_argument('--host', default='0.0.0.0')
|
|
|
|
- parser.add_argument('--port', default=8080)
|
|
|
|
- opt = parser.parse_args()
|
|
|
|
-
|
|
|
|
- app_str = 'server:app' # make the app string equal to whatever the name of this file is
|
|
|
|
- uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
|
|
|
|
|
|
+ res_len = 0
|
|
|
|
+ res = None
|
|
|
|
+ for i in [-1, 0, 1, 2]:
|
|
|
|
+ _res, html = table_res(img, i)
|
|
|
|
+ print(html)
|
|
|
|
+ _res_len = cal_html_to_chs(html)
|
|
|
|
+ if _res_len > res_len:
|
|
|
|
+ res = _res
|
|
|
|
+ res_len = _res_len
|
|
|
|
+
|
|
|
|
+ if res:
|
|
|
|
+ return res[0]['res']
|
|
|
|
+ else:
|
|
|
|
+ raise Exception('无法识别')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+print('table system init success!')
|