xujiayue пре 1 година
родитељ
комит
1c4f4ae13f
1 измењених фајлова са 65 додато и 57 уклоњено
  1. 65 57
      server.py

+ 65 - 57
server.py

@@ -46,13 +46,12 @@ bl = torch.cuda.is_available()
 logger.info(f'是否可使用GPU=======>{bl}')
 
 app = FastAPI()
-templates = Jinja2Templates(directory = 'templates')
+templates = Jinja2Templates(directory='templates')
 
 model_selection_options = ['ocr-layout', 'ocr-logo']
-model_dict = {model_name: None for model_name in model_selection_options} #set up model cache
-
-colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
+model_dict = {model_name: None for model_name in model_selection_options}  # set up model cache
 
+colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)]  # for bbox plotting
 
 if model_dict['ocr-layout'] is None:
     model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
@@ -63,8 +62,9 @@ if model_dict['ocr-logo'] is None:
     model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
     logger.info("========>模型加载成功")
 
+
 ##############################################
-#-------------GET Request Routes--------------
+# -------------GET Request Routes--------------
 ##############################################
 @app.get("/")
 def home(request: Request):
@@ -72,9 +72,10 @@ def home(request: Request):
     '''
 
     return templates.TemplateResponse('home.html', {
-            "request": request,
-            "model_selection_options": model_selection_options,
-        })
+        "request": request,
+        "model_selection_options": model_selection_options,
+    })
+
 
 @app.get("/drag_and_drop_detect")
 def drag_and_drop_detect(request: Request):
@@ -84,69 +85,67 @@ def drag_and_drop_detect(request: Request):
     '''
 
     return templates.TemplateResponse('drag_and_drop_detect.html',
-            {"request": request,
-            "model_selection_options": model_selection_options,
-        })
+                                      {"request": request,
+                                       "model_selection_options": model_selection_options,
+                                       })
 
 
 ##############################################
-#------------POST Request Routes--------------
+# ------------POST Request Routes--------------
 ##############################################
 @app.post("/")
 def detect_via_web_form(request: Request,
                         file_list: List[UploadFile] = File(...),
                         model_name: str = Form(...),
                         img_size: int = Form(1824)):
-
     '''
     Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
     Intended for human (non-api) users.
     Returns: HTML template render showing bbox data and base64 encoded image
     '''
 
-    #assume input validated properly if we got here
+    # assume input validated properly if we got here
     if model_dict[model_name] is None:
         model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
 
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
-                    for file in file_list]
+                 for file in file_list]
 
-    #create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
-    #using cvtColor instead of [...,::-1] to keep array contiguous in RAM
+    # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
+    # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
 
-    results = model_dict[model_name](img_batch_rgb, size = img_size)
+    results = model_dict[model_name](img_batch_rgb, size=img_size)
 
-    json_results = results_to_json(results,model_dict[model_name])
+    json_results = results_to_json(results, model_dict[model_name])
 
     img_str_list = []
-    #plot bboxes on the image
+    # plot bboxes on the image
     for img, bbox_list in zip(img_batch, json_results):
         for bbox in bbox_list:
             label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
             plot_one_box(bbox['bbox'], img, label=label,
-                    color=colors[int(bbox['class'])], line_thickness=3)
+                         color=colors[int(bbox['class'])], line_thickness=3)
 
         img_str_list.append(base64EncodeImage(img))
 
-    #escape the apostrophes in the json string representation
-    encoded_json_results = str(json_results).replace("'",r"\'").replace('"',r'\"')
+    # escape the apostrophes in the json string representation
+    encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')
 
     return templates.TemplateResponse('show_results.html', {
-            'request': request,
-            'bbox_image_data_zipped': zip(img_str_list,json_results), #unzipped in jinja2 template
-            'bbox_data_str': encoded_json_results,
-        })
+        'request': request,
+        'bbox_image_data_zipped': zip(img_str_list, json_results),  # unzipped in jinja2 template
+        'bbox_data_str': encoded_json_results,
+    })
 
 
 @app.post("/detect")
 @web_try()
 def detect_via_api(request: Request,
-                file_list: List[UploadFile] = File(...),
-                model_name: str = Form(...),
-                img_size: Optional[int] = Form(1920),
-                download_image: Optional[bool] = Form(False)):
-
+                   file_list: List[UploadFile] = File(...),
+                   model_name: str = Form(...),
+                   img_size: Optional[int] = Form(1920),
+                   download_image: Optional[bool] = Form(False)):
     '''
     Requires an image file upload, model name (ex. yolov5s).
     Optional image size parameter (Default 1824)
@@ -159,46 +158,47 @@ def detect_via_api(request: Request,
     '''
 
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
-                for file in file_list]
+                 for file in file_list]
 
-    #create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
-    #using cvtColor instead of [...,::-1] to keep array contiguous in RAM
+    # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
+    # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
 
-    results = model_dict[model_name](img_batch_rgb, size = img_size)
-    json_results = results_to_json(results,model_dict[model_name])
+    results = model_dict[model_name](img_batch_rgb, size=img_size)
+    json_results = results_to_json(results, model_dict[model_name])
 
     if download_image:
         for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
             for bbox in bbox_list:
                 label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
                 plot_one_box(bbox['bbox'], img, label=label,
-                        color=colors[int(bbox['class'])], line_thickness=3)
+                             color=colors[int(bbox['class'])], line_thickness=3)
 
-            payload = {'image_base64':base64EncodeImage(img)}
+            payload = {'image_base64': base64EncodeImage(img)}
             json_results[idx].append(payload)
 
-    encoded_json_results = str(json_results).replace("'",r'"')
+    encoded_json_results = str(json_results).replace("'", r'"')
     return encoded_json_results
 
+
 ##############################################
-#--------------Helper Functions---------------
+# --------------Helper Functions---------------
 ##############################################
 
 def results_to_json(results, model):
     ''' Converts yolo model output to json (list of list of dicts)'''
     return [
-                [
-                    {
-                    "class": int(pred[5]),
-                    "class_name": model.model.names[int(pred[5])],
-                    "bbox": [int(x) for x in pred[:4].tolist()], #convert bbox results to int from float
-                    "confidence": float(pred[4]),
-                    }
-                for pred in result
-                ]
-            for result in results.xyxy
-            ]
+        [
+            {
+                "class": int(pred[5]),
+                "class_name": model.model.names[int(pred[5])],
+                "bbox": [int(x) for x in pred[:4].tolist()],  # convert bbox results to int from float
+                "confidence": float(pred[4]),
+            }
+            for pred in result
+        ]
+        for result in results.xyxy
+    ]
 
 
 def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
@@ -223,19 +223,27 @@ def base64EncodeImage(img):
 
     return im_b64
 
+
+@app.get("/ping", description="健康检查")
+def ping():
+    logger.info("->ping")
+    return "pong!"
+
+
 if __name__ == '__main__':
     import uvicorn
     import argparse
+
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', default = 'localhost')
-    parser.add_argument('--port', default = 8080)
+    parser.add_argument('--host', default='localhost')
+    parser.add_argument('--port', default=8080)
     parser.add_argument('--precache-models', action='store_true',
-            help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
+                        help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
     opt = parser.parse_args()
 
     # if opt.precache_models:
     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
     #                     for model_name in model_selection_options}
 
-    app_str = 'server:app' #make the app string equal to whatever the name of this file is
-    uvicorn.run(app_str, host= opt.host, port=int(opt.port), reload=True)
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)