|
@@ -46,13 +46,12 @@ bl = torch.cuda.is_available()
|
|
|
logger.info(f'是否可使用GPU=======>{bl}')
|
|
|
|
|
|
app = FastAPI()
|
|
|
-templates = Jinja2Templates(directory = 'templates')
|
|
|
+templates = Jinja2Templates(directory='templates')
|
|
|
|
|
|
model_selection_options = ['ocr-layout', 'ocr-logo']
|
|
|
-model_dict = {model_name: None for model_name in model_selection_options} #set up model cache
|
|
|
-
|
|
|
-colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
|
|
|
+model_dict = {model_name: None for model_name in model_selection_options} # set up model cache
|
|
|
|
|
|
+colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] # for bbox plotting
|
|
|
|
|
|
if model_dict['ocr-layout'] is None:
|
|
|
model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
@@ -63,8 +62,9 @@ if model_dict['ocr-logo'] is None:
|
|
|
model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
|
|
|
logger.info("========>模型加载成功")
|
|
|
|
|
|
+
|
|
|
##############################################
|
|
|
-#-------------GET Request Routes--------------
|
|
|
+# -------------GET Request Routes--------------
|
|
|
##############################################
|
|
|
@app.get("/")
|
|
|
def home(request: Request):
|
|
@@ -72,9 +72,10 @@ def home(request: Request):
|
|
|
'''
|
|
|
|
|
|
return templates.TemplateResponse('home.html', {
|
|
|
- "request": request,
|
|
|
- "model_selection_options": model_selection_options,
|
|
|
- })
|
|
|
+ "request": request,
|
|
|
+ "model_selection_options": model_selection_options,
|
|
|
+ })
|
|
|
+
|
|
|
|
|
|
@app.get("/drag_and_drop_detect")
|
|
|
def drag_and_drop_detect(request: Request):
|
|
@@ -84,69 +85,67 @@ def drag_and_drop_detect(request: Request):
|
|
|
'''
|
|
|
|
|
|
return templates.TemplateResponse('drag_and_drop_detect.html',
|
|
|
- {"request": request,
|
|
|
- "model_selection_options": model_selection_options,
|
|
|
- })
|
|
|
+ {"request": request,
|
|
|
+ "model_selection_options": model_selection_options,
|
|
|
+ })
|
|
|
|
|
|
|
|
|
##############################################
|
|
|
-#------------POST Request Routes--------------
|
|
|
+# ------------POST Request Routes--------------
|
|
|
##############################################
|
|
|
@app.post("/")
|
|
|
def detect_via_web_form(request: Request,
|
|
|
file_list: List[UploadFile] = File(...),
|
|
|
model_name: str = Form(...),
|
|
|
img_size: int = Form(1824)):
|
|
|
-
|
|
|
'''
|
|
|
Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
|
|
|
Intended for human (non-api) users.
|
|
|
Returns: HTML template render showing bbox data and base64 encoded image
|
|
|
'''
|
|
|
|
|
|
- #assume input validated properly if we got here
|
|
|
+ # assume input validated properly if we got here
|
|
|
if model_dict[model_name] is None:
|
|
|
model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
|
|
|
|
img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
|
|
|
- for file in file_list]
|
|
|
+ for file in file_list]
|
|
|
|
|
|
- #create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
|
|
|
- #using cvtColor instead of [...,::-1] to keep array contiguous in RAM
|
|
|
+ # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
|
|
|
+ # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
|
|
|
img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
|
|
|
|
|
|
- results = model_dict[model_name](img_batch_rgb, size = img_size)
|
|
|
+ results = model_dict[model_name](img_batch_rgb, size=img_size)
|
|
|
|
|
|
- json_results = results_to_json(results,model_dict[model_name])
|
|
|
+ json_results = results_to_json(results, model_dict[model_name])
|
|
|
|
|
|
img_str_list = []
|
|
|
- #plot bboxes on the image
|
|
|
+ # plot bboxes on the image
|
|
|
for img, bbox_list in zip(img_batch, json_results):
|
|
|
for bbox in bbox_list:
|
|
|
label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
|
|
|
plot_one_box(bbox['bbox'], img, label=label,
|
|
|
- color=colors[int(bbox['class'])], line_thickness=3)
|
|
|
+ color=colors[int(bbox['class'])], line_thickness=3)
|
|
|
|
|
|
img_str_list.append(base64EncodeImage(img))
|
|
|
|
|
|
- #escape the apostrophes in the json string representation
|
|
|
- encoded_json_results = str(json_results).replace("'",r"\'").replace('"',r'\"')
|
|
|
+ # escape the apostrophes in the json string representation
|
|
|
+ encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')
|
|
|
|
|
|
return templates.TemplateResponse('show_results.html', {
|
|
|
- 'request': request,
|
|
|
- 'bbox_image_data_zipped': zip(img_str_list,json_results), #unzipped in jinja2 template
|
|
|
- 'bbox_data_str': encoded_json_results,
|
|
|
- })
|
|
|
+ 'request': request,
|
|
|
+ 'bbox_image_data_zipped': zip(img_str_list, json_results), # unzipped in jinja2 template
|
|
|
+ 'bbox_data_str': encoded_json_results,
|
|
|
+ })
|
|
|
|
|
|
|
|
|
@app.post("/detect")
|
|
|
@web_try()
|
|
|
def detect_via_api(request: Request,
|
|
|
- file_list: List[UploadFile] = File(...),
|
|
|
- model_name: str = Form(...),
|
|
|
- img_size: Optional[int] = Form(1920),
|
|
|
- download_image: Optional[bool] = Form(False)):
|
|
|
-
|
|
|
+ file_list: List[UploadFile] = File(...),
|
|
|
+ model_name: str = Form(...),
|
|
|
+ img_size: Optional[int] = Form(1920),
|
|
|
+ download_image: Optional[bool] = Form(False)):
|
|
|
'''
|
|
|
Requires an image file upload, model name (ex. yolov5s).
|
|
|
Optional image size parameter (Default 1824)
|
|
@@ -159,46 +158,47 @@ def detect_via_api(request: Request,
|
|
|
'''
|
|
|
|
|
|
img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
|
|
|
- for file in file_list]
|
|
|
+ for file in file_list]
|
|
|
|
|
|
- #create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
|
|
|
- #using cvtColor instead of [...,::-1] to keep array contiguous in RAM
|
|
|
+ # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
|
|
|
+ # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
|
|
|
img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
|
|
|
|
|
|
- results = model_dict[model_name](img_batch_rgb, size = img_size)
|
|
|
- json_results = results_to_json(results,model_dict[model_name])
|
|
|
+ results = model_dict[model_name](img_batch_rgb, size=img_size)
|
|
|
+ json_results = results_to_json(results, model_dict[model_name])
|
|
|
|
|
|
if download_image:
|
|
|
for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
|
|
|
for bbox in bbox_list:
|
|
|
label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
|
|
|
plot_one_box(bbox['bbox'], img, label=label,
|
|
|
- color=colors[int(bbox['class'])], line_thickness=3)
|
|
|
+ color=colors[int(bbox['class'])], line_thickness=3)
|
|
|
|
|
|
- payload = {'image_base64':base64EncodeImage(img)}
|
|
|
+ payload = {'image_base64': base64EncodeImage(img)}
|
|
|
json_results[idx].append(payload)
|
|
|
|
|
|
- encoded_json_results = str(json_results).replace("'",r'"')
|
|
|
+ encoded_json_results = str(json_results).replace("'", r'"')
|
|
|
return encoded_json_results
|
|
|
|
|
|
+
|
|
|
##############################################
|
|
|
-#--------------Helper Functions---------------
|
|
|
+# --------------Helper Functions---------------
|
|
|
##############################################
|
|
|
|
|
|
def results_to_json(results, model):
|
|
|
''' Converts yolo model output to json (list of list of dicts)'''
|
|
|
return [
|
|
|
- [
|
|
|
- {
|
|
|
- "class": int(pred[5]),
|
|
|
- "class_name": model.model.names[int(pred[5])],
|
|
|
- "bbox": [int(x) for x in pred[:4].tolist()], #convert bbox results to int from float
|
|
|
- "confidence": float(pred[4]),
|
|
|
- }
|
|
|
- for pred in result
|
|
|
- ]
|
|
|
- for result in results.xyxy
|
|
|
- ]
|
|
|
+ [
|
|
|
+ {
|
|
|
+ "class": int(pred[5]),
|
|
|
+ "class_name": model.model.names[int(pred[5])],
|
|
|
+ "bbox": [int(x) for x in pred[:4].tolist()], # convert bbox results to int from float
|
|
|
+ "confidence": float(pred[4]),
|
|
|
+ }
|
|
|
+ for pred in result
|
|
|
+ ]
|
|
|
+ for result in results.xyxy
|
|
|
+ ]
|
|
|
|
|
|
|
|
|
def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
|
|
@@ -223,19 +223,27 @@ def base64EncodeImage(img):
|
|
|
|
|
|
return im_b64
|
|
|
|
|
|
+
|
|
|
+@app.get("/ping", description="健康检查")
|
|
|
+def ping():
|
|
|
+ logger.info("->ping")
|
|
|
+ return "pong!"
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
import uvicorn
|
|
|
import argparse
|
|
|
+
|
|
|
parser = argparse.ArgumentParser()
|
|
|
- parser.add_argument('--host', default = 'localhost')
|
|
|
- parser.add_argument('--port', default = 8080)
|
|
|
+ parser.add_argument('--host', default='localhost')
|
|
|
+ parser.add_argument('--port', default=8080)
|
|
|
parser.add_argument('--precache-models', action='store_true',
|
|
|
- help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
|
|
|
+ help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
|
|
|
opt = parser.parse_args()
|
|
|
|
|
|
# if opt.precache_models:
|
|
|
# model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
|
|
|
# for model_name in model_selection_options}
|
|
|
|
|
|
- app_str = 'server:app' #make the app string equal to whatever the name of this file is
|
|
|
- uvicorn.run(app_str, host= opt.host, port=int(opt.port), reload=True)
|
|
|
+ app_str = 'server:app' # make the app string equal to whatever the name of this file is
|
|
|
+ uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
|