Ver Fonte

修改内容:
logo检测模型添加服务

xujiayue há 1 ano atrás
pai
commit
c77f1bc8ec
1 ficheiros alterados com 7 adições e 3 exclusões
  1. 7 3
      server.py

+ 7 - 3
server.py

@@ -48,16 +48,20 @@ logger.info(f'是否可使用GPU=======>{bl}')
 app = FastAPI()
 templates = Jinja2Templates(directory = 'templates')
 
-model_selection_options = ['ocr-layout']
+model_selection_options = ['ocr-layout', 'ocr-logo']
 model_dict = {model_name: None for model_name in model_selection_options} #set up model cache
 
 colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
 
 
 if model_dict['ocr-layout'] is None:
-          model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
-          logger.info("========>模型加载成功")
+    model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
+    logger.info("========>模型加载成功")
 
+# logo检测
+if model_dict['ocr-logo'] is None:
+    model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
+    logger.info("========>模型加载成功")
 
 ##############################################
 #-------------GET Request Routes--------------