yolov7.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from pathlib import Path
  2. from typing import List
  3. from numpy import ndarray
  4. import torch
  5. from core.detectors.base import LayoutDetectorBase
  6. from core.layout import LayoutBox
  7. PROJ_ROOT = Path(__file__).parent.parent.parent
  8. YOLO_DIR = str(PROJ_ROOT / "yolov7")
  9. WEIGHTS = str(PROJ_ROOT / "yiliv7_718.pt")
  10. class Yolov7Detector(LayoutDetectorBase):
  11. print("======加载 YOLOv7 模型======")
  12. print(f"是否可使用GPU=======>{torch.cuda.is_available()}")
  13. _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  14. _model = torch.hub.load(YOLO_DIR, "custom", WEIGHTS, source="local").to(
  15. _device
  16. )
  17. print("========>模型加载成功")
  18. @classmethod
  19. def predict(cls, img: ndarray, img_size=1824, **kwargs) -> List[LayoutBox]:
  20. results = cls._model([img], size=img_size)
  21. return [
  22. [
  23. LayoutBox(
  24. clazz=int(pred[5]),
  25. clazz_name=cls._model.model.names[int(pred[5])],
  26. bbox=[
  27. int(x) for x in pred[:4].tolist()
  28. ], # convert bbox results to int from float
  29. conf=float(pred[4]),
  30. )
  31. for pred in result
  32. ]
  33. for result in results.xyxy
  34. ][0]