|
@@ -14,7 +14,7 @@ YOLO_DIR = '/workspace/yolov5'
|
|
|
# WEIGHTS = '/data/yolov5/runs/train/yolov5x_layout_reuslt37/weights/best.pt'
|
|
|
WEIGHTS = '/workspace/best.pt'
|
|
|
|
|
|
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
print('====',torch.cuda.is_available())
|
|
|
|
|
|
app = FastAPI()
|
|
@@ -25,6 +25,12 @@ model_dict = {model_name: None for model_name in model_selection_options} #set u
|
|
|
|
|
|
colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
|
|
|
|
|
|
+
|
|
|
+if model_dict[model_name] is None:
|
|
|
+ model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
|
+ print("========>加载成功")
|
|
|
+
|
|
|
+
|
|
|
##############################################
|
|
|
#-------------GET Request Routes--------------
|
|
|
##############################################
|
|
@@ -120,9 +126,6 @@ def detect_via_api(request: Request,
|
|
|
Intended for API usage.
|
|
|
'''
|
|
|
|
|
|
- if model_dict[model_name] is None:
|
|
|
- model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
|
|
|
-
|
|
|
img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
|
|
|
for file in file_list]
|
|
|
|