瀏覽代碼

fix: 裁剪bbox以防止超出图像边界

jingze_cheng 7 月之前
父節點
當前提交
cfb4a14b31
共有 3 個文件被更改,包括 23 次插入3 次删除
  1. 1 1
      core/detectors/paddle_yolo/__init__.py
  2. 1 0
      core/detectors/yolov7.py
  3. 21 2
      core/predictor.py

+ 1 - 1
core/detectors/paddle_yolo/__init__.py

@@ -85,7 +85,7 @@ class PaddleYoloDetector(LayoutDetectorBase):
                 bbox = box[2:6]
                 results.append(
                     LayoutBox(
-                        clazz=clazz, clazz_name=None, bbox=bbox, conf=conf
+                        clazz=clazz, clazz_name=None, bbox=bbox, conf=conf, full_img=img,
                     )
                 )
             results = _filter_by_conf(results, conf_threshold)

+ 1 - 0
core/detectors/yolov7.py

@@ -32,6 +32,7 @@ class Yolov7Detector(LayoutDetectorBase):
                         int(x) for x in pred[:4].tolist()
                     ],  # convert bbox results to int from float
                     conf=float(pred[4]),
+                    full_img=img,
                 )
                 for pred in result
             ]

+ 21 - 2
core/predictor.py

@@ -20,7 +20,7 @@ def predict_img(
 
         return PaddleYoloDetector.predict(
             img,
-            conf_threshold=0.3,
+            conf_threshold=0.1,
             overlaps_iou_threshold=0.85,
             overlaps_max_count=3,
             **kwargs,
@@ -36,11 +36,13 @@ class LayoutBox:
         clazz_name: Optional[str],
         bbox: List[int],
         conf: float,
+        full_img: Optional[np.ndarray]=None,
     ):
         self.clazz = clazz
         self.clazz_name = clazz_name
         self.bbox = bbox
         self.conf = conf
+        self.full_img = full_img
 
     @property
     def ltrb(self):
@@ -84,10 +86,15 @@ class LayoutBox:
         """
         返回中间服务所需的格式
         """
+        if isinstance(self.full_img, np.ndarray):
+            h, w = self.full_img.shape[:2]
+            bbox = self.clip_bbox(self.bbox, h, w)
+        else:
+            bbox = self.ltrb
         return {
             "class": self.clazz,
             "class_name": self.clazz_name,
-            "bbox": self.ltrb,
+            "bbox": bbox,
             "confidence": self.conf,
         }
 
@@ -100,5 +107,17 @@ class LayoutBox:
             conf=d["confidence"],
         )
 
+    @staticmethod
+    def clip_bbox(bbox: List[int], img_h: int, img_w: int) -> List[int]:
+        """
+        裁剪 bbox 尺寸以防止超出图像边界。
+        """
+        l, t, r, b = bbox
+        l = max(0, int(l))
+        t = max(0, int(t))
+        r = min(img_w, int(r))
+        b = min(img_h, int(b))
+        return [l, t, r, b]
+
     def __repr__(self):
         return f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, bbox={self.bbox}, conf={self.conf})"