liangzhongquan há 2 anos atrás
pai
commit
3bf9369c9b
2 ficheiros alterados com 11 adições e 5 exclusões
  1. 4 1
      Makefile
  2. 7 4
      server.py

+ 4 - 1
Makefile

@@ -5,9 +5,12 @@ COMMIT_SHA1     := $(shell git rev-parse HEAD)
 AUTHOR          := $(shell git show -s --format='%an')
 
 
-.PHONY: all gpu
+.PHONY: all gpu cpu
 all: gpu
 gpu:
 	@docker build -t registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):gpu --build-arg VERSION=gpu .
 	# @docker push registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):gpu
 
+cpu:
+	@docker build -t registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):cpu --build-arg VERSION=cpu .
+	@docker push registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):cpu

+ 7 - 4
server.py

@@ -14,7 +14,7 @@ YOLO_DIR = '/workspace/yolov5'
 # WEIGHTS = '/data/yolov5/runs/train/yolov5x_layout_reuslt37/weights/best.pt'
 WEIGHTS = '/workspace/best.pt'
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print('====',torch.cuda.is_available())
 
 app = FastAPI()
@@ -25,6 +25,12 @@ model_dict = {model_name: None for model_name in model_selection_options} #set u
 
 colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] #for bbox plotting
 
+
+if model_dict[model_name] is None:
+          model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
+          print("========>加载成功")
+
+
 ##############################################
 #-------------GET Request Routes--------------
 ##############################################
@@ -120,9 +126,6 @@ def detect_via_api(request: Request,
     Intended for API usage.
     '''
 
-    if model_dict[model_name] is None:
-          model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
-
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
                 for file in file_list]