林廷熠 1 rok temu
rodzic
commit
2fa551ba1f
4 zmienionych plików z 41 dodań i 37 usunięć
  1. 4 0
      .gitignore
  2. 11 0
      run.py
  3. 26 37
      server.py
  4. 0 0
      test.py

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+**/__pycache__/**
+.DS_Store
+
+*.pt

+ 11 - 0
run.py

@@ -0,0 +1,11 @@
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 26 - 37
server.py

@@ -12,13 +12,8 @@ import random
 import sys
 import logging
 
-YOLO_DIR = '/workspace/yolov5'
-# WEIGHTS = '/data/yolov5/runs/train/yolov5x_layout_reuslt37/weights/best.pt'
-# WEIGHTS = '/workspace/best.pt'
-# WEIGHTS = '/workspace/yili.pt'
-# WEIGHTS = '/workspace/best2.pt'
-
-WEIGHTS = '/workspace/11-17.pt'
+YOLO_DIR = './yolov7'
+WEIGHTS = './7_18.pt'
 
 logger = logging.getLogger('log')
 logger.setLevel(logging.DEBUG)
@@ -30,10 +25,10 @@ while logger.hasHandlers():
 
 # file log 写入文件配置
 formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # 日志的格式
-fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8')  # 日志文件路径文件名称,编码格式
-fh.setLevel(logging.DEBUG)  # 日志打印级别
-fh.setFormatter(formatter)
-logger.addHandler(fh)
+# fh = logging.FileHandler(r'./log/be.log', encoding='utf-8')  # 日志文件路径文件名称,编码格式
+# fh.setLevel(logging.DEBUG)  # 日志打印级别
+# fh.setFormatter(formatter)
+# logger.addHandler(fh)
 
 # console log 控制台输出控制
 ch = logging.StreamHandler(sys.stdout)
@@ -48,21 +43,15 @@ logger.info(f'是否可使用GPU=======>{bl}')
 app = FastAPI()
 templates = Jinja2Templates(directory='templates')
 
-model_selection_options = ['ocr-layout', 'ocr-logo']
+model_selection_options = ['ocr-layout']
 model_dict = {model_name: None for model_name in model_selection_options}  # set up model cache
 
 colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)]  # for bbox plotting
 
 if model_dict['ocr-layout'] is None:
-    model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
-    logger.info("========>模型加载成功")
-
-# logo检测
-if model_dict['ocr-logo'] is None:
-    model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
+    model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', WEIGHTS, source='local').to(device)
     logger.info("========>模型加载成功")
 
-
 ##############################################
 # -------------GET Request Routes--------------
 ##############################################
@@ -106,7 +95,7 @@ def detect_via_web_form(request: Request,
 
     # assume input validated properly if we got here
     if model_dict[model_name] is None:
-        model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
+        model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', WEIGHTS, source='local').to(device)
 
     img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
                  for file in file_list]
@@ -230,20 +219,20 @@ def ping():
     return "pong!"
 
 
-if __name__ == '__main__':
-    import uvicorn
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--host', default='localhost')
-    parser.add_argument('--port', default=8080)
-    parser.add_argument('--precache-models', action='store_true',
-                        help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
-    opt = parser.parse_args()
-
-    # if opt.precache_models:
-    #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
-    #                     for model_name in model_selection_options}
-
-    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
-    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
+# if __name__ == '__main__':
+#     import uvicorn
+#     import argparse
+#
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument('--host', default='localhost')
+#     parser.add_argument('--port', default=8080)
+#     parser.add_argument('--precache-models', action='store_true',
+#                         help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
+#     opt = parser.parse_args()
+#
+#     # if opt.precache_models:
+#     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
+#     #                     for model_name in model_selection_options}
+#
+#     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+#     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 0 - 0
test.py