Kaynağa Gözat

Merge branch 'master' of http://gogsb.soaringnova.com/chenguilong/ocr-paddle

Conflicts:
	server.py

	modified:   server.py
liutao 2 yıl önce
ebeveyn
işleme
90d0eac532
1 değiştirilmiş dosya ile 35 ekleme ve 46 silme
  1. 35 46
      server.py

+ 35 - 46
server.py

@@ -1,44 +1,19 @@
-import io
 import json
-import re
-from fastapi import FastAPI, Request, File, UploadFile, Body
+from base64 import b64decode
+
+import cv2
+import numpy as np
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from sx_utils.sximage import *
-from sx_utils.sxtime import sxtimeit
-from sx_utils.sxweb import web_try
-import requests
-from PIL import Image
 from pydantic import BaseModel
-import sys
-import logging
+from paddleocr import PaddleOCR, PPStructure
+from sx_utils.sxweb import *
+from sx_utils.sximage import *
+
 import os
-import cv2
-from paddleocr import PaddleOCR
-
-logger = logging.getLogger('log')
-logger.setLevel(logging.DEBUG)
-
-# 调用模块时,如果错误引用,比如多次调用,每次会添加Handler,造成重复日志,这边每次都移除掉所有的handler,后面在重新添加,可以解决这类问题
-while logger.hasHandlers():
-    for i in logger.handlers:
-        logger.removeHandler(i)
-
-# file log 写入文件配置
-formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # 日志的格式
-# 本地运行时,这部分需注释
-# fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8')  # 日志文件路径文件名称,编码格式
-# fh.setLevel(logging.DEBUG)  # 日志打印级别
-# fh.setFormatter(formatter)
-# logger.addHandler(fh)
-
-# console log 控制台输出控制
-ch = logging.StreamHandler(sys.stdout)
-ch.setLevel(logging.DEBUG)
-ch.setFormatter(formatter)
-logger.addHandler(ch)
 
+# 初始化app
 app = FastAPI()
-
 origins = ["*"]
 
 app.add_middleware(
@@ -48,29 +23,38 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
+
 use_gpu = False
 if os.getenv('USE_CUDA') == 'gpu':
     use_gpu = True
 
-logger.info(f"->是否使用GPU:{use_gpu}")
+print(f'use gpu: {use_gpu}')
 
-ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./table_rec_infer/",det_model_dir="./table_det_infer/",cls_model_dir="table_cls_infer",lang="ch")
+# 初始化模型
+<<<<<<< HEAD
+table_engine = PPStructure(layout=False, table=True, show_log=True, table_model_dir="/Users/sxkj/opt/python-workspace/yili-ocr/ocr-table/SLANet/")
+=======
+table_engine = PPStructure(layout=False,
+                           table=True,
+                           use_gpu=use_gpu,
+                           show_log=True,
+                           table_model_dir="./tabel_ocr_infer")
+>>>>>>> d046d2d76465ba7ba4d8c3e2cdad1ec5e7ac3b69
 
 
-@app.get("/ping")
-def ping():
-    return "pong!"
+class TableInfo(BaseModel):
+    image: str
 
 
+@app.get("/ping")
+def ping():
+    return 'pong!'
 
 
-class ImageListInfo(BaseModel):
-    images: list
-    img_type: str
 
-@app.post("/ocr_system/paddle")
-@sxtimeit
+@app.post("/ocr_system/table")
 @web_try()
+<<<<<<< HEAD
 def rotate_bound_white_bg(self, image, angle):
     
     (h, w) = image.shape[:2]
@@ -230,6 +214,12 @@ def paddle(request: Request,info: ImageListInfo):
     return res_list
 
 
+=======
+def table(image: TableInfo):
+    img = base64_to_np(image.image)
+    res = table_engine(img)
+    return res[0]['res']
+>>>>>>> 658d5ab26bfc557357332fa8aea4b547f20c5ff2
 
 
 if __name__ == '__main__':
@@ -241,6 +231,5 @@ if __name__ == '__main__':
     parser.add_argument('--port', default=8080)
     opt = parser.parse_args()
 
-
     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)