3 Commity 37c228510d ... b8d49d5a45

Autor SHA1 Wiadomość Data
  jingze_cheng b8d49d5a45 chore: 删除临时输出 8 miesięcy temu
  jingze_cheng 43d208f007 feat: 支持多尺度推理及后处理 8 miesięcy temu
  jingze_cheng cfb4a14b31 fix: 裁剪bbox以防止超出图像边界 8 miesięcy temu

+ 5 - 58
core/detectors/paddle_yolo/__init__.py

@@ -1,4 +1,3 @@
-import itertools
 import threading
 from typing import List
 from numpy import ndarray
@@ -63,14 +62,7 @@ class PaddleYoloDetector(LayoutDetectorBase):
     lock = threading.Lock()
 
     @classmethod
-    def predict(
-        cls,
-        img: ndarray,
-        conf_threshold: float = 0.0,
-        overlaps_iou_threshold: float = 0.9,
-        overlaps_max_count: int = 5,
-        **kwargs
-    ) -> List[LayoutBox]:
+    def predict(cls, img: ndarray, **kwargs) -> List[LayoutBox]:
         try:
             cls.lock.acquire()
             predicts = cls._detector.predict_image(
@@ -79,62 +71,17 @@ class PaddleYoloDetector(LayoutDetectorBase):
             results = []
             boxes = predicts["boxes"].tolist()  # type: ignore
             for box in boxes:
-                print(box)
                 clazz = int(box[0])
                 conf = box[1]
                 bbox = box[2:6]
                 results.append(
                     LayoutBox(
-                        clazz=clazz, clazz_name=None, bbox=bbox, conf=conf
+                        clazz=clazz,
+                        clazz_name=None,
+                        bbox=bbox,
+                        conf=conf,
                     )
                 )
-            results = _filter_by_conf(results, conf_threshold)
-            results = _filter_by_overlaps(
-                results, overlaps_iou_threshold, overlaps_max_count
-            )
         finally:
             cls.lock.release()
         return results
-
-
-def _filter_by_conf(
-    boxes: List[LayoutBox],
-    conf_threshold: float,
-) -> List[LayoutBox]:
-    # 按置信度过滤LayoutBox
-    boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
-    return boxes
-
-
-def _filter_by_overlaps(
-    boxes: List[LayoutBox],
-    overlaps_iou_threshold: float,
-    overlaps_max_count: int,
-) -> List[LayoutBox]:
-    """
-    对多个 iou 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
-    """
-    # 按置信度进行排序
-    boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
-    # 每一个桶中都是重叠区域较大的LayoutBox
-    buckets: List[List[LayoutBox]] = []
-
-    # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
-    def get_bucket(box: LayoutBox):
-        for bucket in buckets:
-            for e in bucket:
-                if box.iou(e) >= overlaps_iou_threshold:
-                    return bucket
-        return None
-
-    for layout in boxes:
-        bucket = get_bucket(layout)
-        # 若当前不存在于目标layout重叠的内容,则新建一个桶
-        if not bucket:
-            buckets.append([layout])
-        # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
-        elif len(bucket) < overlaps_max_count:
-            bucket.append(layout)
-    # 将所用桶中的数据合为一个列表
-    new_layouts = list(itertools.chain.from_iterable(buckets))
-    return new_layouts

+ 1 - 1
core/detectors/yolov7.py

@@ -3,8 +3,8 @@ from typing import List
 from numpy import ndarray
 import torch
 
-from core.predictor import LayoutBox
 from core.detectors.base import LayoutDetectorBase
+from core.layout import LayoutBox
 
 PROJ_ROOT = Path(__file__).parent.parent.parent
 YOLO_DIR = str(PROJ_ROOT / "yolov7")

+ 79 - 0
core/layout.py

@@ -0,0 +1,79 @@
+from __future__ import annotations
+from typing import Dict, List, Optional
+
+import numpy as np
+
+
+class LayoutBox:
+    def __init__(
+        self,
+        clazz: int,
+        bbox: List[int],
+        conf: float,
+        clazz_name: Optional[str] = None,
+    ):
+        self.clazz = clazz
+        self.clazz_name = clazz_name
+        self.bbox = bbox
+        self.conf = conf
+
+    @property
+    def ltrb(self):
+        l, t, r, b = self.bbox
+        return [int(x) for x in [l, t, r, b]]
+
+    @property
+    def area(self):
+        l, t, r, b = self.ltrb
+        return (r - l + 1) * (b - t + 1)
+
+    def iou(self, other):
+        boxA = self.ltrb
+        boxB = other.ltrb
+        boxA = [int(x) for x in boxA]
+        boxB = [int(x) for x in boxB]
+
+        xA = max(boxA[0], boxB[0])
+        yA = max(boxA[1], boxB[1])
+        xB = min(boxA[2], boxB[2])
+        yB = min(boxA[3], boxB[3])
+
+        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
+
+        boxAArea = self.area
+        boxBArea = other.area
+
+        iou = interArea / float(boxAArea + boxBArea - interArea)
+
+        return iou
+
+    def to_dict(self) -> Dict:
+        return {
+            "class": self.clazz,
+            "class_name": self.clazz_name,
+            "bbox": self.bbox,
+            "confidence": self.conf,
+        }
+
+    def to_service_dict(self) -> Dict:
+        """
+        返回中间服务所需的格式
+        """
+        return {
+            "class": self.clazz,
+            "class_name": self.clazz_name,
+            "bbox": self.ltrb,
+            "confidence": self.conf,
+        }
+
+    @classmethod
+    def from_dict(cls, d) -> LayoutBox:
+        return cls(
+            clazz=d["class"],
+            clazz_name=d["class_name"],
+            bbox=d["bbox"],
+            conf=d["confidence"],
+        )
+
+    def __repr__(self):
+        return f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, bbox={self.bbox}, conf={self.conf})"

+ 159 - 0
core/post_process.py

@@ -0,0 +1,159 @@
+import itertools
+from typing import List
+
+from core.layout import LayoutBox
+
+
+def merge_boxes_list(
+    boxes_list: List[List[LayoutBox]],
+    img_w: int,
+    img_h: int,
+    method: str = "nms",
+    iou_threshold=0.5,
+) -> List[LayoutBox]:
+    """合并多组检测框列表,调用 Weighted-Boxes-Fusion 库实现。
+    可用于合并多个模型的预测结果,或合并单个模型多次的预测结果。
+    See: https://github.com/ZFTurbo/Weighted-Boxes-Fusion
+
+    method
+    Args:
+        boxes_list (List[List[LayoutBox]]):
+            多组检测框列表
+        img_w (int):
+            图像宽度
+        img_h (int):
+            图像高度
+        method (str, optional):
+            合并方法名,可选值: ["nms", "soft_nms", "nmw", "wbf"]. Defaults to "nms".
+        iou_threshold (float, optional):
+            bbox 匹配的 IoU 阈值. Defaults to 0.5.
+
+    Returns:
+        List[LayoutBox]: 合并后的检测框列表
+    """
+
+    def ltrb_to_nltrb(ltrb, img_w, img_h):
+        """
+        Normalize ltrb.
+        """
+        l, t, r, b = ltrb
+        nl = l / img_w
+        nt = t / img_h
+        nr = r / img_w
+        nb = b / img_h
+        return [nl, nt, nr, nb]
+
+    def nltrb_to_ltrb(nltrb, img_w, img_h):
+        """
+        Denormalize normalized ltrb.
+        """
+        nl, nt, nr, nb = nltrb
+        l = nl * img_w
+        t = nt * img_h
+        r = nr * img_w
+        b = nb * img_h
+        return [l, t, r, b]
+
+    from ensemble_boxes import (
+        nms,
+        soft_nms,
+        non_maximum_weighted,
+        weighted_boxes_fusion,
+    )
+
+    merge_funcs = {
+        "nms": nms,
+        "soft_nms": soft_nms,
+        "nmw": non_maximum_weighted,
+        "wbf": weighted_boxes_fusion,
+    }
+    assert method in merge_funcs.keys()
+    merge_func = merge_funcs[method]
+
+    nltrbs_list = [
+        [ltrb_to_nltrb(b.ltrb, img_w, img_h) for b in boxes]
+        for boxes in boxes_list
+    ]
+    scores_list = [[b.conf for b in boxes] for boxes in boxes_list]
+    labels_list = [[b.clazz for b in boxes] for boxes in boxes_list]
+
+    nltrbs, scores, labels = merge_func(
+        nltrbs_list, scores_list, labels_list, iou_thr=iou_threshold
+    )
+
+    merged_boxes = [
+        LayoutBox(
+            clazz=int(label),
+            bbox=nltrb_to_ltrb(nltrb, img_w, img_h),
+            conf=float(score),
+        )
+        for nltrb, score, label in zip(nltrbs, scores, labels)
+    ]
+    return merged_boxes
+
+
+def clip_boxes_to_image_bound(
+    boxes: List[LayoutBox], img_w: int, img_h: int
+) -> List[LayoutBox]:
+    """
+    裁剪检测框尺寸以防止超出图像边界。
+    """
+
+    def clip_bbox(bbox: List[int], img_w: int, img_h: int) -> List[int]:
+        l, t, r, b = bbox
+        l = max(0, int(l))
+        t = max(0, int(t))
+        r = min(img_w, int(r))
+        b = min(img_h, int(b))
+        return [l, t, r, b]
+
+    for box in boxes:
+        box.bbox = clip_bbox(box.bbox, img_w, img_h)
+
+    return boxes
+
+
+def filter_boxes_by_conf(
+    boxes: List[LayoutBox],
+    conf_threshold: float,
+) -> List[LayoutBox]:
+    """
+    按置信度过滤检测框。
+    """
+    boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
+    return boxes
+
+
+def filter_boxes_by_overlaps(
+    boxes: List[LayoutBox],
+    overlaps_iou_threshold: float,
+    overlaps_max_count: int,
+) -> List[LayoutBox]:
+    """
+    按置信度和 IoU 过滤检测框。
+    对多个 IoU 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
+    """
+    # 按置信度进行排序
+    boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
+    # 每一个桶中都是重叠区域较大的LayoutBox
+    buckets: List[List[LayoutBox]] = []
+
+    # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
+    def get_bucket(box: LayoutBox):
+        for bucket in buckets:
+            for e in bucket:
+                if box.iou(e) >= overlaps_iou_threshold:
+                    return bucket
+        return None
+
+    for box in boxes:
+        bucket = get_bucket(box)
+        # 若当前不存在于目标layout重叠的内容,则新建一个桶
+        if not bucket:
+            buckets.append([box])
+        # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
+        elif len(bucket) < overlaps_max_count:
+            bucket.append(box)
+    # 将所用桶中的数据合为一个列表
+    new_boxes = list(itertools.chain.from_iterable(buckets))
+    return new_boxes

+ 40 - 84
core/predictor.py

@@ -1,11 +1,23 @@
 from __future__ import annotations
-from typing import Dict, List, Optional
+from typing import List
 
 import numpy as np
 
+from core.layout import LayoutBox
+from core.post_process import (
+    clip_boxes_to_image_bound,
+    filter_boxes_by_conf,
+    filter_boxes_by_overlaps,
+    merge_boxes_list,
+)
+
 
 def predict_img(
-    img: np.ndarray, model_name: str, img_size: int, **kwargs
+    img: np.ndarray,
+    model_name: str,
+    img_size: int,
+    multi_scale: bool,
+    **kwargs,
 ) -> List[LayoutBox]:
     """
     根据模型名称预测布局检测框。模型仅在被调用时加载(以节省GPU资源)。
@@ -13,92 +25,36 @@ def predict_img(
     if model_name == "ocr-layout":
         from core.detectors.yolov7 import Yolov7Detector
 
-        return Yolov7Detector.predict(img, img_size, **kwargs)
+        if not multi_scale:
+            return Yolov7Detector.predict(img, img_size, **kwargs)
 
-    elif model_name == "ocr-layout-paddle":
-        from core.detectors.paddle_yolo import PaddleYoloDetector
-
-        return PaddleYoloDetector.predict(
-            img,
-            conf_threshold=0.3,
-            overlaps_iou_threshold=0.85,
-            overlaps_max_count=3,
-            **kwargs,
+        scale_factors = [0.5, 0.75, 1.0, 1.25, 1.5]
+        img_sizes = [int(img_size * factor) for factor in scale_factors]
+        boxes_list = [Yolov7Detector.predict(img, size) for size in img_sizes]
+        img_h, img_w = img.shape[:2]
+        boxes = merge_boxes_list(
+            boxes_list, img_w, img_h, method="nms", iou_threshold=0.55
         )
-    else:
-        raise RuntimeError(f"Invalid model name: {model_name}")
-
-
-class LayoutBox:
-    def __init__(
-        self,
-        clazz: int,
-        clazz_name: Optional[str],
-        bbox: List[int],
-        conf: float,
-    ):
-        self.clazz = clazz
-        self.clazz_name = clazz_name
-        self.bbox = bbox
-        self.conf = conf
-
-    @property
-    def ltrb(self):
-        l, t, r, b = self.bbox
-        return [int(x) for x in [l, t, r, b]]
-
-    @property
-    def area(self):
-        l, t, r, b = self.ltrb
-        return (r - l + 1) * (b - t + 1)
+        return boxes
 
-    def iou(self, other):
-        boxA = self.ltrb
-        boxB = other.ltrb
-        boxA = [int(x) for x in boxA]
-        boxB = [int(x) for x in boxB]
-
-        xA = max(boxA[0], boxB[0])
-        yA = max(boxA[1], boxB[1])
-        xB = min(boxA[2], boxB[2])
-        yB = min(boxA[3], boxB[3])
-
-        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
-
-        boxAArea = self.area
-        boxBArea = other.area
-
-        iou = interArea / float(boxAArea + boxBArea - interArea)
-
-        return iou
+    elif model_name == "ocr-layout-paddle":
+        from core.detectors.paddle_yolo import PaddleYoloDetector
 
-    def to_dict(self) -> Dict:
-        return {
-            "class": self.clazz,
-            "class_name": self.clazz_name,
-            "bbox": self.bbox,
-            "confidence": self.conf,
-        }
+        boxes = PaddleYoloDetector.predict(img, **kwargs)
 
-    def to_service_dict(self) -> Dict:
-        """
-        返回中间服务所需的格式
-        """
-        return {
-            "class": self.clazz,
-            "class_name": self.clazz_name,
-            "bbox": self.ltrb,
-            "confidence": self.conf,
-        }
+        # 该模型预测的类型暂时需要重新映射
+        # FIXME: 统一类型
+        _clazz_remap = {11: 10, 12: 11}
+        for b in boxes:
+            if b.clazz in _clazz_remap:
+                b.clazz = _clazz_remap[b.clazz]
 
-    @classmethod
-    def from_dict(cls, d) -> LayoutBox:
-        return cls(
-            clazz=d["class"],
-            clazz_name=d["class_name"],
-            bbox=d["bbox"],
-            conf=d["confidence"],
+        img_h, img_w = img.shape[:2]
+        boxes = clip_boxes_to_image_bound(boxes, img_w, img_h)
+        boxes = filter_boxes_by_conf(boxes, conf_threshold=0.1)
+        boxes = filter_boxes_by_overlaps(
+            boxes, overlaps_iou_threshold=0.8, overlaps_max_count=3
         )
-
-    def __repr__(self):
-        return f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, bbox={self.bbox}, conf={self.conf})"
+        return boxes
+    else:
+        raise RuntimeError(f"Invalid model name: {model_name}")

+ 8 - 24
server.py

@@ -7,7 +7,8 @@ import cv2
 import numpy as np
 
 import base64
-from core.predictor import LayoutBox, predict_img
+from core.predictor import predict_img
+from core.layout import LayoutBox
 from sx_utils import format_print
 
 app = FastAPI()
@@ -31,7 +32,6 @@ clazz_names = [
     "style",
     "table",
     "text",
-    "text_jhl",
     "title",
 ]
 
@@ -69,7 +69,9 @@ def drag_and_drop_detect(request: Request):
 def detect_via_web_form(request: Request,
                         file_list: List[UploadFile] = File(...),
                         model_name: str = Form(...),
-                        img_size: int = Form(1824)):
+                        img_size: int = Form(1824),
+                        multi_scale: bool = Form(False),
+                        ):
     '''
     Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
     Intended for human (non-api) users.
@@ -83,7 +85,7 @@ def detect_via_web_form(request: Request,
     # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
 
-    results = [predict_img(img, model_name, img_size) for img in img_batch_rgb]
+    results = [predict_img(img, model_name, img_size, multi_scale) for img in img_batch_rgb]
     
     json_results = boxes_list_to_json(results, clazz_names)
 
@@ -113,6 +115,7 @@ def detect_via_api(request: Request,
                    file_list: List[UploadFile] = File(...),
                    model_name: str = Form(...),
                    img_size: int = Form(1920),
+                   multi_scale: bool = Form(False),
                    download_image: Optional[bool] = Form(False)):
     '''
     Requires an image file upload, model name (ex. yolov5s).
@@ -133,7 +136,7 @@ def detect_via_api(request: Request,
     # 转换图片格式
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
     # 选用相关模型进行模版识别
-    results = [predict_img(img, model_name, img_size) for img in img_batch_rgb]
+    results = [predict_img(img, model_name, img_size, multi_scale) for img in img_batch_rgb]
     # 处理结果数据
     json_results = boxes_list_to_json(results, clazz_names)
 
@@ -212,22 +215,3 @@ def boxes_list_to_json(boxes_list: List[List[LayoutBox]], clazz_names: List[str]
 def ping():
     print("->ping")
     return "pong!"
-
-
-# if __name__ == '__main__':
-#     import uvicorn
-#     import argparse
-#
-#     parser = argparse.ArgumentParser()
-#     parser.add_argument('--host', default='localhost')
-#     parser.add_argument('--port', default=8080)
-#     parser.add_argument('--precache-models', action='store_true',
-#                         help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
-#     opt = parser.parse_args()
-#
-#     # if opt.precache_models:
-#     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
-#     #                     for model_name in model_selection_options}
-#
-#     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
-#     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 9 - 0
templates/drag_and_drop_detect.html

@@ -43,6 +43,14 @@ TODO:
         name="img_size"
         value="1824"
       />
+      <input
+        type="checkbox"
+        id="multi_scale"
+        name="multi_scale"
+      />
+      <label for="multi_scale" class="form-label">
+        <b>Multi-scale Inference</b>
+      </label>
     </div>
 
     <div class="col">
@@ -329,6 +337,7 @@ TODO:
     formData.append('file_list', img)
     formData.append('model_name', $('#model_name').val())
     formData.append('img_size', $('#img_size').val())
+    formData.append('multi_scale', $('#multi_scale').val())
 
     $.ajax({
       url: '/detect',

+ 10 - 0
templates/home.html

@@ -43,6 +43,16 @@ Implements a simple Bootstrap 5 form submission interface for YOLOv5 detection.
         value="1824"
       />
     </div>
+    <div class="my-2">
+      <input
+        type="checkbox"
+        id="multi_scale"
+        name="multi_scale"
+      />
+      <label for="multi_scale" class="form-label">
+        Multi-scale Inference
+      </label>
+    </div>
 
     <button class="btn btn-primary" type="submit">Submit</button>
   </form>