Ver código fonte

modified: server.py
SLANet/

liutao 2 anos atrás
pai
commit
658d5ab26b
1 arquivos alterados com 32 adições e 65 exclusões
  1. 32 65
      server.py

+ 32 - 65
server.py

@@ -1,44 +1,19 @@
-import io
 import json
-import re
-from fastapi import FastAPI, Request, File, UploadFile, Body
+from base64 import b64decode
+
+import cv2
+import numpy as np
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from sx_utils.sximage import *
-from sx_utils.sxtime import sxtimeit
-from sx_utils.sxweb import web_try
-import requests
-from PIL import Image
 from pydantic import BaseModel
-import sys
-import logging
+from paddleocr import PaddleOCR, PPStructure
+from sx_utils.sxweb import *
+from sx_utils.sximage import *
+
 import os
-import cv2
-from paddleocr import PaddleOCR
-
-logger = logging.getLogger('log')
-logger.setLevel(logging.DEBUG)
-
-# 调用模块时,如果错误引用,比如多次调用,每次会添加Handler,造成重复日志,这边每次都移除掉所有的handler,后面在重新添加,可以解决这类问题
-while logger.hasHandlers():
-    for i in logger.handlers:
-        logger.removeHandler(i)
-
-# file log 写入文件配置
-formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # 日志的格式
-# 本地运行时,这部分需注释
-# fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8')  # 日志文件路径文件名称,编码格式
-# fh.setLevel(logging.DEBUG)  # 日志打印级别
-# fh.setFormatter(formatter)
-# logger.addHandler(fh)
-
-# console log 控制台输出控制
-ch = logging.StreamHandler(sys.stdout)
-ch.setLevel(logging.DEBUG)
-ch.setFormatter(formatter)
-logger.addHandler(ch)
 
+# 初始化app
 app = FastAPI()
-
 origins = ["*"]
 
 app.add_middleware(
@@ -48,48 +23,41 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
+
 use_gpu = False
 if os.getenv('USE_CUDA') == 'gpu':
     use_gpu = True
 
-logger.info(f"->是否使用GPU:{use_gpu}")
+print(f'use gpu: {use_gpu}')
 
-ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./table_rec_infer/",det_model_dir="./table_det_infer/",cls_model_dir="table_cls_infer",lang="ch")
+# 初始化模型
+<<<<<<< HEAD
+table_engine = PPStructure(layout=False, table=True, show_log=True, table_model_dir="/Users/sxkj/opt/python-workspace/yili-ocr/ocr-table/SLANet/")
+=======
+table_engine = PPStructure(layout=False,
+                           table=True,
+                           use_gpu=use_gpu,
+                           show_log=True,
+                           table_model_dir="./tabel_ocr_infer")
+>>>>>>> d046d2d76465ba7ba4d8c3e2cdad1ec5e7ac3b69
 
 
-@app.get("/ping")
-def ping():
-    return "pong!"
+class TableInfo(BaseModel):
+    image: str
 
 
+@app.get("/ping")
+def ping():
+    return 'pong!'
 
 
-class ImageListInfo(BaseModel):
-    images: list
-    img_type: str
 
-@app.post("/ocr_system/paddle")
-@sxtimeit
+@app.post("/ocr_system/table")
 @web_try()
-def paddle(request: Request,info: ImageListInfo):
-    logger.info(f"->图片数量:{len(info.images)}")
-    res_list = []
-    for b_img in info.images:
-        img = base64_to_np(b_img)
-        result=ocr.ocr(img,cls=True)
-        r_list = []
-        for text_list in result:
-            if len(text_list) >= 1:
-                data = {}
-                data["confidence"]= text_list[1][1]
-                data["text"] = text_list[1][0]
-                data["type"] = info.img_type
-                data["text_region"]= text_list[0]
-                r_list.append(data)
-        res_list.append(r_list)
-    return res_list
-
-
+def table(image: TableInfo):
+    img = base64_to_np(image.image)
+    res = table_engine(img)
+    return res[0]['res']
 
 
 if __name__ == '__main__':
@@ -101,6 +69,5 @@ if __name__ == '__main__':
     parser.add_argument('--port', default=8080)
     opt = parser.parse_args()
 
-
     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)