瀏覽代碼

修改内容:
健康检查

xujiayue 1 年之前
父節點
當前提交
1c4f4ae13f
共有 1 個文件被更改,包括 65 次插入57 次删除
  1. 65 57
      server.py

+ 65 - 57
server.py

@@ -46,13 +46,12 @@ bl = torch.cuda.is_available()
 logger.info(f'是否可使用GPU=======>{bl}')
 logger.info(f'是否可使用GPU=======>{bl}')
 
 
 app = FastAPI()
 app = FastAPI()
-templates = Jinja2Templates(directory = 'templates')
+templates = Jinja2Templates(directory='templates')
 
 
 model_selection_options = ['ocr-layout', 'ocr-logo']
 model_selection_options = ['ocr-layout', 'ocr-logo']
-model_dict = {model_name: None for model_name in model_selection_options} #set up model cache
-
-colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
+model_dict = {model_name: None for model_name in model_selection_options}  # set up model cache
 
 
+colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)]  # for bbox plotting
 
 
 if model_dict['ocr-layout'] is None:
 if model_dict['ocr-layout'] is None:
     model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
     model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
@@ -63,8 +62,9 @@ if model_dict['ocr-logo'] is None:
     model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
     model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
     logger.info("========>模型加载成功")
     logger.info("========>模型加载成功")
 
 
+
 ##############################################
 ##############################################
-#-------------GET Request Routes--------------
+# -------------GET Request Routes--------------
 ##############################################
 ##############################################
 @app.get("/")
 @app.get("/")
 def home(request: Request):
 def home(request: Request):
@@ -72,9 +72,10 @@ def home(request: Request):
     '''
     '''
 
 
     return templates.TemplateResponse('home.html', {
     return templates.TemplateResponse('home.html', {
-            "request": request,
-            "model_selection_options": model_selection_options,
-        })
+        "request": request,
+        "model_selection_options": model_selection_options,
+    })
+
 
 
 @app.get("/drag_and_drop_detect")
 @app.get("/drag_and_drop_detect")
 def drag_and_drop_detect(request: Request):
 def drag_and_drop_detect(request: Request):
@@ -84,69 +85,67 @@ def drag_and_drop_detect(request: Request):
     '''
     '''
 
 
     return templates.TemplateResponse('drag_and_drop_detect.html',
     return templates.TemplateResponse('drag_and_drop_detect.html',
-            {"request": request,
-            "model_selection_options": model_selection_options,
-        })
+                                      {"request": request,
+                                       "model_selection_options": model_selection_options,
+                                       })
 
 
 
 
 ##############################################
 ##############################################
-#------------POST Request Routes--------------
+# ------------POST Request Routes--------------
 ##############################################
 ##############################################
 @app.post("/")
 @app.post("/")
 def detect_via_web_form(request: Request,
 def detect_via_web_form(request: Request,
                         file_list: List[UploadFile] = File(...),
                         file_list: List[UploadFile] = File(...),
                         model_name: str = Form(...),
                         model_name: str = Form(...),
                         img_size: int = Form(1824)):
                         img_size: int = Form(1824)):
-
     '''
     '''
     Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
     Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
     Intended for human (non-api) users.
     Intended for human (non-api) users.
     Returns: HTML template render showing bbox data and base64 encoded image
     Returns: HTML template render showing bbox data and base64 encoded image
     '''
     '''
 
 
-    #assume input validated properly if we got here
+    # assume input validated properly if we got here
     if model_dict[model_name] is None:
     if model_dict[model_name] is None:
         model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
         model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
 
 
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
-                    for file in file_list]
+                 for file in file_list]
 
 
-    #create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
-    #using cvtColor instead of [...,::-1] to keep array contiguous in RAM
+    # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
+    # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
 
 
-    results = model_dict[model_name](img_batch_rgb, size = img_size)
+    results = model_dict[model_name](img_batch_rgb, size=img_size)
 
 
-    json_results = results_to_json(results,model_dict[model_name])
+    json_results = results_to_json(results, model_dict[model_name])
 
 
     img_str_list = []
     img_str_list = []
-    #plot bboxes on the image
+    # plot bboxes on the image
     for img, bbox_list in zip(img_batch, json_results):
     for img, bbox_list in zip(img_batch, json_results):
         for bbox in bbox_list:
         for bbox in bbox_list:
             label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
             label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
             plot_one_box(bbox['bbox'], img, label=label,
             plot_one_box(bbox['bbox'], img, label=label,
-                    color=colors[int(bbox['class'])], line_thickness=3)
+                         color=colors[int(bbox['class'])], line_thickness=3)
 
 
         img_str_list.append(base64EncodeImage(img))
         img_str_list.append(base64EncodeImage(img))
 
 
-    #escape the apostrophes in the json string representation
-    encoded_json_results = str(json_results).replace("'",r"\'").replace('"',r'\"')
+    # escape the apostrophes in the json string representation
+    encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')
 
 
     return templates.TemplateResponse('show_results.html', {
     return templates.TemplateResponse('show_results.html', {
-            'request': request,
-            'bbox_image_data_zipped': zip(img_str_list,json_results), #unzipped in jinja2 template
-            'bbox_data_str': encoded_json_results,
-        })
+        'request': request,
+        'bbox_image_data_zipped': zip(img_str_list, json_results),  # unzipped in jinja2 template
+        'bbox_data_str': encoded_json_results,
+    })
 
 
 
 
 @app.post("/detect")
 @app.post("/detect")
 @web_try()
 @web_try()
 def detect_via_api(request: Request,
 def detect_via_api(request: Request,
-                file_list: List[UploadFile] = File(...),
-                model_name: str = Form(...),
-                img_size: Optional[int] = Form(1920),
-                download_image: Optional[bool] = Form(False)):
-
+                   file_list: List[UploadFile] = File(...),
+                   model_name: str = Form(...),
+                   img_size: Optional[int] = Form(1920),
+                   download_image: Optional[bool] = Form(False)):
     '''
     '''
     Requires an image file upload, model name (ex. yolov5s).
     Requires an image file upload, model name (ex. yolov5s).
     Optional image size parameter (Default 1824)
     Optional image size parameter (Default 1824)
@@ -159,46 +158,47 @@ def detect_via_api(request: Request,
     '''
     '''
 
 
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
-                for file in file_list]
+                 for file in file_list]
 
 
-    #create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
-    #using cvtColor instead of [...,::-1] to keep array contiguous in RAM
+    # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
+    # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
 
 
-    results = model_dict[model_name](img_batch_rgb, size = img_size)
-    json_results = results_to_json(results,model_dict[model_name])
+    results = model_dict[model_name](img_batch_rgb, size=img_size)
+    json_results = results_to_json(results, model_dict[model_name])
 
 
     if download_image:
     if download_image:
         for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
         for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
             for bbox in bbox_list:
             for bbox in bbox_list:
                 label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
                 label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
                 plot_one_box(bbox['bbox'], img, label=label,
                 plot_one_box(bbox['bbox'], img, label=label,
-                        color=colors[int(bbox['class'])], line_thickness=3)
+                             color=colors[int(bbox['class'])], line_thickness=3)
 
 
-            payload = {'image_base64':base64EncodeImage(img)}
+            payload = {'image_base64': base64EncodeImage(img)}
             json_results[idx].append(payload)
             json_results[idx].append(payload)
 
 
-    encoded_json_results = str(json_results).replace("'",r'"')
+    encoded_json_results = str(json_results).replace("'", r'"')
     return encoded_json_results
     return encoded_json_results
 
 
+
 ##############################################
 ##############################################
-#--------------Helper Functions---------------
+# --------------Helper Functions---------------
 ##############################################
 ##############################################
 
 
 def results_to_json(results, model):
 def results_to_json(results, model):
     ''' Converts yolo model output to json (list of list of dicts)'''
     ''' Converts yolo model output to json (list of list of dicts)'''
     return [
     return [
-                [
-                    {
-                    "class": int(pred[5]),
-                    "class_name": model.model.names[int(pred[5])],
-                    "bbox": [int(x) for x in pred[:4].tolist()], #convert bbox results to int from float
-                    "confidence": float(pred[4]),
-                    }
-                for pred in result
-                ]
-            for result in results.xyxy
-            ]
+        [
+            {
+                "class": int(pred[5]),
+                "class_name": model.model.names[int(pred[5])],
+                "bbox": [int(x) for x in pred[:4].tolist()],  # convert bbox results to int from float
+                "confidence": float(pred[4]),
+            }
+            for pred in result
+        ]
+        for result in results.xyxy
+    ]
 
 
 
 
 def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
 def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
@@ -223,19 +223,27 @@ def base64EncodeImage(img):
 
 
     return im_b64
     return im_b64
 
 
+
+@app.get("/ping", description="健康检查")
+def ping():
+    logger.info("->ping")
+    return "pong!"
+
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     import uvicorn
     import uvicorn
     import argparse
     import argparse
+
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', default = 'localhost')
-    parser.add_argument('--port', default = 8080)
+    parser.add_argument('--host', default='localhost')
+    parser.add_argument('--port', default=8080)
     parser.add_argument('--precache-models', action='store_true',
     parser.add_argument('--precache-models', action='store_true',
-            help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
+                        help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
     opt = parser.parse_args()
     opt = parser.parse_args()
 
 
     # if opt.precache_models:
     # if opt.precache_models:
     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
     #                     for model_name in model_selection_options}
     #                     for model_name in model_selection_options}
 
 
-    app_str = 'server:app' #make the app string equal to whatever the name of this file is
-    uvicorn.run(app_str, host= opt.host, port=int(opt.port), reload=True)
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)