|
@@ -1,3 +1,4 @@
|
|
|
+# -*- coding:utf-8 -*-
|
|
|
from fastapi import FastAPI, Request, Form, File, UploadFile
|
|
|
from fastapi.templating import Jinja2Templates
|
|
|
from pydantic import BaseModel
|
|
@@ -42,7 +43,7 @@ logger.addHandler(ch)
|
|
|
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
bl = torch.cuda.is_available()
|
|
|
-logger.info(f'是否可使用GPU=======>{bl}')
|
|
|
+#logger.info(f'是否可使用GPU=======>{bl}')
|
|
|
|
|
|
app = FastAPI()
|
|
|
templates = Jinja2Templates(directory = 'templates')
|
|
@@ -55,6 +56,7 @@ colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)]
|
|
|
|
|
|
if model_dict['ocr-layout'] is None:
|
|
|
model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
|
+# print(model_dict['ocr-layout'])
|
|
|
logger.info("========>模型加载成功")
|
|
|
|
|
|
|
|
@@ -91,7 +93,7 @@ def drag_and_drop_detect(request: Request):
|
|
|
def detect_via_web_form(request: Request,
|
|
|
file_list: List[UploadFile] = File(...),
|
|
|
model_name: str = Form(...),
|
|
|
- img_size: int = Form(1824)):
|
|
|
+ img_size: int = Form(800)):
|
|
|
|
|
|
'''
|
|
|
Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
|
|
@@ -109,7 +111,8 @@ def detect_via_web_form(request: Request,
|
|
|
#create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
|
|
|
#using cvtColor instead of [...,::-1] to keep array contiguous in RAM
|
|
|
img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
|
|
|
-
|
|
|
+ print(img_size)
|
|
|
+ print('111111')
|
|
|
results = model_dict[model_name](img_batch_rgb, size = img_size)
|
|
|
|
|
|
json_results = results_to_json(results,model_dict[model_name])
|
|
@@ -139,7 +142,7 @@ def detect_via_web_form(request: Request,
|
|
|
def detect_via_api(request: Request,
|
|
|
file_list: List[UploadFile] = File(...),
|
|
|
model_name: str = Form(...),
|
|
|
- img_size: Optional[int] = Form(1824),
|
|
|
+ img_size: Optional[int] = Form(800),
|
|
|
download_image: Optional[bool] = Form(False)):
|
|
|
|
|
|
'''
|
|
@@ -159,7 +162,10 @@ def detect_via_api(request: Request,
|
|
|
#create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
|
|
|
#using cvtColor instead of [...,::-1] to keep array contiguous in RAM
|
|
|
img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
|
|
|
-
|
|
|
+ print('111111')
|
|
|
+ print(img_size)
|
|
|
+ img_size= 800
|
|
|
+ print(img_size)
|
|
|
results = model_dict[model_name](img_batch_rgb, size = img_size)
|
|
|
json_results = results_to_json(results,model_dict[model_name])
|
|
|
|