|
@@ -1,44 +1,19 @@
|
|
-import io
|
|
|
|
import json
|
|
import json
|
|
-import re
|
|
|
|
-from fastapi import FastAPI, Request, File, UploadFile, Body
|
|
|
|
|
|
+from base64 import b64decode
|
|
|
|
+
|
|
|
|
+import cv2
|
|
|
|
+import numpy as np
|
|
|
|
+from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
-from sx_utils.sximage import *
|
|
|
|
-from sx_utils.sxtime import sxtimeit
|
|
|
|
-from sx_utils.sxweb import web_try
|
|
|
|
-import requests
|
|
|
|
-from PIL import Image
|
|
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
-import sys
|
|
|
|
-import logging
|
|
|
|
|
|
+from paddleocr import PaddleOCR, PPStructure
|
|
|
|
+from sx_utils.sxweb import *
|
|
|
|
+from sx_utils.sximage import *
|
|
|
|
+
|
|
import os
|
|
import os
|
|
-import cv2
|
|
|
|
-from paddleocr import PaddleOCR
|
|
|
|
-
|
|
|
|
-logger = logging.getLogger('log')
|
|
|
|
-logger.setLevel(logging.DEBUG)
|
|
|
|
-
|
|
|
|
-# 调用模块时,如果错误引用,比如多次调用,每次会添加Handler,造成重复日志,这边每次都移除掉所有的handler,后面在重新添加,可以解决这类问题
|
|
|
|
-while logger.hasHandlers():
|
|
|
|
- for i in logger.handlers:
|
|
|
|
- logger.removeHandler(i)
|
|
|
|
-
|
|
|
|
-# file log 写入文件配置
|
|
|
|
-formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') # 日志的格式
|
|
|
|
-# 本地运行时,这部分需注释
|
|
|
|
-# fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8') # 日志文件路径文件名称,编码格式
|
|
|
|
-# fh.setLevel(logging.DEBUG) # 日志打印级别
|
|
|
|
-# fh.setFormatter(formatter)
|
|
|
|
-# logger.addHandler(fh)
|
|
|
|
-
|
|
|
|
-# console log 控制台输出控制
|
|
|
|
-ch = logging.StreamHandler(sys.stdout)
|
|
|
|
-ch.setLevel(logging.DEBUG)
|
|
|
|
-ch.setFormatter(formatter)
|
|
|
|
-logger.addHandler(ch)
|
|
|
|
|
|
|
|
|
|
+# Initialize the app
|
|
app = FastAPI()
|
|
app = FastAPI()
|
|
-
|
|
|
|
origins = ["*"]
|
|
origins = ["*"]
|
|
|
|
|
|
app.add_middleware(
|
|
app.add_middleware(
|
|
@@ -48,48 +23,41 @@ app.add_middleware(
|
|
allow_methods=["*"],
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
)
|
|
|
|
+
|
|
use_gpu = False
|
|
use_gpu = False
|
|
if os.getenv('USE_CUDA') == 'gpu':
|
|
if os.getenv('USE_CUDA') == 'gpu':
|
|
use_gpu = True
|
|
use_gpu = True
|
|
|
|
|
|
-logger.info(f"->是否使用GPU:{use_gpu}")
|
|
|
|
|
|
+print(f'use gpu: {use_gpu}')
|
|
|
|
|
|
-ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./table_rec_infer/",det_model_dir="./table_det_infer/",cls_model_dir="table_cls_infer",lang="ch")
|
|
|
|
|
|
+# Initialize the table-recognition model.
+#
+# This span previously contained unresolved merge-conflict markers
+# (<<<<<<< HEAD / ======= / >>>>>>> d046d2d...) with two competing
+# assignments, which is a syntax error at import time. The incoming side is
+# kept because it is the only variant that forwards the USE_CUDA-derived
+# `use_gpu` flag computed above, and because the HEAD variant loaded weights
+# from an absolute path inside one developer's home directory.
+#
+# NOTE(review): "tabel_ocr_infer" looks like a misspelling of "table", but it
+# is a runtime path and must match the directory name on disk, so it is kept
+# byte-for-byte - confirm before renaming either side.
+table_engine = PPStructure(layout=False,
+                           table=True,
+                           use_gpu=use_gpu,
+                           show_log=True,
+                           table_model_dir="./tabel_ocr_infer")
|
|
|
|
|
|
|
|
|
|
-@app.get("/ping")
|
|
|
|
-def ping():
|
|
|
|
- return "pong!"
|
|
|
|
|
|
+class TableInfo(BaseModel):  # Request model taken as the `image` parameter of table().
|
|
|
|
+ image: str  # Handed to base64_to_np() by table(); presumably a base64-encoded image - confirm.
|
|
|
|
|
|
|
|
|
|
|
|
+@app.get("/ping")  # Registers ping() for GET /ping on the FastAPI app.
|
|
|
|
+def ping():  # Takes no arguments and reads no state.
|
|
|
|
+ return 'pong!'  # Always returns this fixed string.
|
|
|
|
|
|
|
|
|
|
-class ImageListInfo(BaseModel):
|
|
|
|
- images: list
|
|
|
|
- img_type: str
|
|
|
|
|
|
|
|
-@app.post("/ocr_system/paddle")
|
|
|
|
-@sxtimeit
|
|
|
|
|
|
+@app.post("/ocr_system/table")
|
|
@web_try()
|
|
@web_try()
|
|
-def paddle(request: Request,info: ImageListInfo):
|
|
|
|
- logger.info(f"->图片数量:{len(info.images)}")
|
|
|
|
- res_list = []
|
|
|
|
- for b_img in info.images:
|
|
|
|
- img = base64_to_np(b_img)
|
|
|
|
- result=ocr.ocr(img,cls=True)
|
|
|
|
- r_list = []
|
|
|
|
- for text_list in result:
|
|
|
|
- if len(text_list) >= 1:
|
|
|
|
- data = {}
|
|
|
|
- data["confidence"]= text_list[1][1]
|
|
|
|
- data["text"] = text_list[1][0]
|
|
|
|
- data["type"] = info.img_type
|
|
|
|
- data["text_region"]= text_list[0]
|
|
|
|
- r_list.append(data)
|
|
|
|
- res_list.append(r_list)
|
|
|
|
- return res_list
|
|
|
|
-
|
|
|
|
-
|
|
|
|
|
|
+def table(image: TableInfo):  # Handler for the table route; `image` is the TableInfo request model.
|
|
|
|
+ img = base64_to_np(image.image)  # base64_to_np is not defined in this diff (presumably from the sx_utils star imports - confirm).
|
|
|
|
+ res = table_engine(img)  # Calls the module-level PPStructure instance on the decoded value.
|
|
|
|
+ return res[0]['res']  # Returns only the 'res' entry of the first result; assumes at least one result exists - TODO confirm.
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
@@ -101,6 +69,5 @@ if __name__ == '__main__':
|
|
parser.add_argument('--port', default=8080)
|
|
parser.add_argument('--port', default=8080)
|
|
opt = parser.parse_args()
|
|
opt = parser.parse_args()
|
|
|
|
|
|
-
|
|
|
|
app_str = 'server:app' # make the app string equal to whatever the name of this file is
|
|
app_str = 'server:app' # make the app string equal to whatever the name of this file is
|
|
uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
|
|
uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
|