|
@@ -48,16 +48,20 @@ logger.info(f'是否可使用GPU=======>{bl}')
|
|
|
app = FastAPI()
|
|
|
templates = Jinja2Templates(directory = 'templates')
|
|
|
|
|
|
-model_selection_options = ['ocr-layout']
|
|
|
+model_selection_options = ['ocr-layout', 'ocr-logo']
|
|
|
model_dict = {model_name: None for model_name in model_selection_options} #set up model cache
|
|
|
|
|
|
colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
|
|
|
|
|
|
|
|
|
if model_dict['ocr-layout'] is None:
|
|
|
- model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
|
- logger.info("========>模型加载成功")
|
|
|
+ model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
|
+ logger.info("========>模型加载成功")
|
|
|
|
|
|
+# logo检测
|
|
|
+if model_dict['ocr-logo'] is None:
|
|
|
+ model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
|
|
|
+ logger.info("========>模型加载成功")
|
|
|
|
|
|
##############################################
|
|
|
#-------------GET Request Routes--------------
|