3 Commits 37c228510d ... b8d49d5a45

Author SHA1 Message Date
  jingze_cheng b8d49d5a45 chore: 删除临时输出 8 months ago
  jingze_cheng 43d208f007 feat: 支持多尺度推理及后处理 8 months ago
  jingze_cheng cfb4a14b31 fix: 裁剪bbox以防止超出图像边界 8 months ago

+ 5 - 58
core/detectors/paddle_yolo/__init__.py

@@ -1,4 +1,3 @@
-import itertools
 import threading
 import threading
 from typing import List
 from typing import List
 from numpy import ndarray
 from numpy import ndarray
@@ -63,14 +62,7 @@ class PaddleYoloDetector(LayoutDetectorBase):
     lock = threading.Lock()
     lock = threading.Lock()
 
 
     @classmethod
     @classmethod
-    def predict(
-        cls,
-        img: ndarray,
-        conf_threshold: float = 0.0,
-        overlaps_iou_threshold: float = 0.9,
-        overlaps_max_count: int = 5,
-        **kwargs
-    ) -> List[LayoutBox]:
+    def predict(cls, img: ndarray, **kwargs) -> List[LayoutBox]:
         try:
         try:
             cls.lock.acquire()
             cls.lock.acquire()
             predicts = cls._detector.predict_image(
             predicts = cls._detector.predict_image(
@@ -79,62 +71,17 @@ class PaddleYoloDetector(LayoutDetectorBase):
             results = []
             results = []
             boxes = predicts["boxes"].tolist()  # type: ignore
             boxes = predicts["boxes"].tolist()  # type: ignore
             for box in boxes:
             for box in boxes:
-                print(box)
                 clazz = int(box[0])
                 clazz = int(box[0])
                 conf = box[1]
                 conf = box[1]
                 bbox = box[2:6]
                 bbox = box[2:6]
                 results.append(
                 results.append(
                     LayoutBox(
                     LayoutBox(
-                        clazz=clazz, clazz_name=None, bbox=bbox, conf=conf
+                        clazz=clazz,
+                        clazz_name=None,
+                        bbox=bbox,
+                        conf=conf,
                     )
                     )
                 )
                 )
-            results = _filter_by_conf(results, conf_threshold)
-            results = _filter_by_overlaps(
-                results, overlaps_iou_threshold, overlaps_max_count
-            )
         finally:
         finally:
             cls.lock.release()
             cls.lock.release()
         return results
         return results
-
-
-def _filter_by_conf(
-    boxes: List[LayoutBox],
-    conf_threshold: float,
-) -> List[LayoutBox]:
-    # 按置信度过滤LayoutBox
-    boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
-    return boxes
-
-
-def _filter_by_overlaps(
-    boxes: List[LayoutBox],
-    overlaps_iou_threshold: float,
-    overlaps_max_count: int,
-) -> List[LayoutBox]:
-    """
-    对多个 iou 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
-    """
-    # 按置信度进行排序
-    boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
-    # 每一个桶中都是重叠区域较大的LayoutBox
-    buckets: List[List[LayoutBox]] = []
-
-    # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
-    def get_bucket(box: LayoutBox):
-        for bucket in buckets:
-            for e in bucket:
-                if box.iou(e) >= overlaps_iou_threshold:
-                    return bucket
-        return None
-
-    for layout in boxes:
-        bucket = get_bucket(layout)
-        # 若当前不存在于目标layout重叠的内容,则新建一个桶
-        if not bucket:
-            buckets.append([layout])
-        # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
-        elif len(bucket) < overlaps_max_count:
-            bucket.append(layout)
-    # 将所用桶中的数据合为一个列表
-    new_layouts = list(itertools.chain.from_iterable(buckets))
-    return new_layouts

+ 1 - 1
core/detectors/yolov7.py

@@ -3,8 +3,8 @@ from typing import List
 from numpy import ndarray
 from numpy import ndarray
 import torch
 import torch
 
 
-from core.predictor import LayoutBox
 from core.detectors.base import LayoutDetectorBase
 from core.detectors.base import LayoutDetectorBase
+from core.layout import LayoutBox
 
 
 PROJ_ROOT = Path(__file__).parent.parent.parent
 PROJ_ROOT = Path(__file__).parent.parent.parent
 YOLO_DIR = str(PROJ_ROOT / "yolov7")
 YOLO_DIR = str(PROJ_ROOT / "yolov7")

+ 79 - 0
core/layout.py

@@ -0,0 +1,79 @@
+from __future__ import annotations
+from typing import Dict, List, Optional
+
+import numpy as np
+
+
+class LayoutBox:
+    def __init__(
+        self,
+        clazz: int,
+        bbox: List[int],
+        conf: float,
+        clazz_name: Optional[str] = None,
+    ):
+        self.clazz = clazz
+        self.clazz_name = clazz_name
+        self.bbox = bbox
+        self.conf = conf
+
+    @property
+    def ltrb(self):
+        l, t, r, b = self.bbox
+        return [int(x) for x in [l, t, r, b]]
+
+    @property
+    def area(self):
+        l, t, r, b = self.ltrb
+        return (r - l + 1) * (b - t + 1)
+
+    def iou(self, other):
+        boxA = self.ltrb
+        boxB = other.ltrb
+        boxA = [int(x) for x in boxA]
+        boxB = [int(x) for x in boxB]
+
+        xA = max(boxA[0], boxB[0])
+        yA = max(boxA[1], boxB[1])
+        xB = min(boxA[2], boxB[2])
+        yB = min(boxA[3], boxB[3])
+
+        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
+
+        boxAArea = self.area
+        boxBArea = other.area
+
+        iou = interArea / float(boxAArea + boxBArea - interArea)
+
+        return iou
+
+    def to_dict(self) -> Dict:
+        return {
+            "class": self.clazz,
+            "class_name": self.clazz_name,
+            "bbox": self.bbox,
+            "confidence": self.conf,
+        }
+
+    def to_service_dict(self) -> Dict:
+        """
+        返回中间服务所需的格式
+        """
+        return {
+            "class": self.clazz,
+            "class_name": self.clazz_name,
+            "bbox": self.ltrb,
+            "confidence": self.conf,
+        }
+
+    @classmethod
+    def from_dict(cls, d) -> LayoutBox:
+        return cls(
+            clazz=d["class"],
+            clazz_name=d["class_name"],
+            bbox=d["bbox"],
+            conf=d["confidence"],
+        )
+
+    def __repr__(self):
+        return f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, bbox={self.bbox}, conf={self.conf})"

+ 159 - 0
core/post_process.py

@@ -0,0 +1,159 @@
+import itertools
+from typing import List
+
+from core.layout import LayoutBox
+
+
+def merge_boxes_list(
+    boxes_list: List[List[LayoutBox]],
+    img_w: int,
+    img_h: int,
+    method: str = "nms",
+    iou_threshold=0.5,
+) -> List[LayoutBox]:
+    """合并多组检测框列表,调用 Weighted-Boxes-Fusion 库实现。
+    可用于合并多个模型的预测结果,或合并单个模型多次的预测结果。
+    See: https://github.com/ZFTurbo/Weighted-Boxes-Fusion
+
+    method
+    Args:
+        boxes_list (List[List[LayoutBox]]):
+            多组检测框列表
+        img_w (int):
+            图像宽度
+        img_h (int):
+            图像高度
+        method (str, optional):
+            合并方法名,可选值: ["nms", "soft_nms", "nmw", "wbf"]. Defaults to "nms".
+        iou_threshold (float, optional):
+            bbox 匹配的 IoU 阈值. Defaults to 0.5.
+
+    Returns:
+        List[LayoutBox]: 合并后的检测框列表
+    """
+
+    def ltrb_to_nltrb(ltrb, img_w, img_h):
+        """
+        Normalize ltrb.
+        """
+        l, t, r, b = ltrb
+        nl = l / img_w
+        nt = t / img_h
+        nr = r / img_w
+        nb = b / img_h
+        return [nl, nt, nr, nb]
+
+    def nltrb_to_ltrb(nltrb, img_w, img_h):
+        """
+        Denormalize normalized ltrb.
+        """
+        nl, nt, nr, nb = nltrb
+        l = nl * img_w
+        t = nt * img_h
+        r = nr * img_w
+        b = nb * img_h
+        return [l, t, r, b]
+
+    from ensemble_boxes import (
+        nms,
+        soft_nms,
+        non_maximum_weighted,
+        weighted_boxes_fusion,
+    )
+
+    merge_funcs = {
+        "nms": nms,
+        "soft_nms": soft_nms,
+        "nmw": non_maximum_weighted,
+        "wbf": weighted_boxes_fusion,
+    }
+    assert method in merge_funcs.keys()
+    merge_func = merge_funcs[method]
+
+    nltrbs_list = [
+        [ltrb_to_nltrb(b.ltrb, img_w, img_h) for b in boxes]
+        for boxes in boxes_list
+    ]
+    scores_list = [[b.conf for b in boxes] for boxes in boxes_list]
+    labels_list = [[b.clazz for b in boxes] for boxes in boxes_list]
+
+    nltrbs, scores, labels = merge_func(
+        nltrbs_list, scores_list, labels_list, iou_thr=iou_threshold
+    )
+
+    merged_boxes = [
+        LayoutBox(
+            clazz=int(label),
+            bbox=nltrb_to_ltrb(nltrb, img_w, img_h),
+            conf=float(score),
+        )
+        for nltrb, score, label in zip(nltrbs, scores, labels)
+    ]
+    return merged_boxes
+
+
+def clip_boxes_to_image_bound(
+    boxes: List[LayoutBox], img_w: int, img_h: int
+) -> List[LayoutBox]:
+    """
+    裁剪检测框尺寸以防止超出图像边界。
+    """
+
+    def clip_bbox(bbox: List[int], img_w: int, img_h: int) -> List[int]:
+        l, t, r, b = bbox
+        l = max(0, int(l))
+        t = max(0, int(t))
+        r = min(img_w, int(r))
+        b = min(img_h, int(b))
+        return [l, t, r, b]
+
+    for box in boxes:
+        box.bbox = clip_bbox(box.bbox, img_w, img_h)
+
+    return boxes
+
+
+def filter_boxes_by_conf(
+    boxes: List[LayoutBox],
+    conf_threshold: float,
+) -> List[LayoutBox]:
+    """
+    按置信度过滤检测框。
+    """
+    boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
+    return boxes
+
+
+def filter_boxes_by_overlaps(
+    boxes: List[LayoutBox],
+    overlaps_iou_threshold: float,
+    overlaps_max_count: int,
+) -> List[LayoutBox]:
+    """
+    按置信度和 IoU 过滤检测框。
+    对多个 IoU 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
+    """
+    # 按置信度进行排序
+    boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
+    # 每一个桶中都是重叠区域较大的LayoutBox
+    buckets: List[List[LayoutBox]] = []
+
+    # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
+    def get_bucket(box: LayoutBox):
+        for bucket in buckets:
+            for e in bucket:
+                if box.iou(e) >= overlaps_iou_threshold:
+                    return bucket
+        return None
+
+    for box in boxes:
+        bucket = get_bucket(box)
+        # 若当前不存在于目标layout重叠的内容,则新建一个桶
+        if not bucket:
+            buckets.append([box])
+        # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
+        elif len(bucket) < overlaps_max_count:
+            bucket.append(box)
+    # 将所用桶中的数据合为一个列表
+    new_boxes = list(itertools.chain.from_iterable(buckets))
+    return new_boxes

+ 40 - 84
core/predictor.py

@@ -1,11 +1,23 @@
 from __future__ import annotations
 from __future__ import annotations
-from typing import Dict, List, Optional
+from typing import List
 
 
 import numpy as np
 import numpy as np
 
 
+from core.layout import LayoutBox
+from core.post_process import (
+    clip_boxes_to_image_bound,
+    filter_boxes_by_conf,
+    filter_boxes_by_overlaps,
+    merge_boxes_list,
+)
+
 
 
 def predict_img(
 def predict_img(
-    img: np.ndarray, model_name: str, img_size: int, **kwargs
+    img: np.ndarray,
+    model_name: str,
+    img_size: int,
+    multi_scale: bool,
+    **kwargs,
 ) -> List[LayoutBox]:
 ) -> List[LayoutBox]:
     """
     """
     根据模型名称预测布局检测框。模型仅在被调用时加载(以节省GPU资源)。
     根据模型名称预测布局检测框。模型仅在被调用时加载(以节省GPU资源)。
@@ -13,92 +25,36 @@ def predict_img(
     if model_name == "ocr-layout":
     if model_name == "ocr-layout":
         from core.detectors.yolov7 import Yolov7Detector
         from core.detectors.yolov7 import Yolov7Detector
 
 
-        return Yolov7Detector.predict(img, img_size, **kwargs)
+        if not multi_scale:
+            return Yolov7Detector.predict(img, img_size, **kwargs)
 
 
-    elif model_name == "ocr-layout-paddle":
-        from core.detectors.paddle_yolo import PaddleYoloDetector
-
-        return PaddleYoloDetector.predict(
-            img,
-            conf_threshold=0.3,
-            overlaps_iou_threshold=0.85,
-            overlaps_max_count=3,
-            **kwargs,
+        scale_factors = [0.5, 0.75, 1.0, 1.25, 1.5]
+        img_sizes = [int(img_size * factor) for factor in scale_factors]
+        boxes_list = [Yolov7Detector.predict(img, size) for size in img_sizes]
+        img_h, img_w = img.shape[:2]
+        boxes = merge_boxes_list(
+            boxes_list, img_w, img_h, method="nms", iou_threshold=0.55
         )
         )
-    else:
-        raise RuntimeError(f"Invalid model name: {model_name}")
-
-
-class LayoutBox:
-    def __init__(
-        self,
-        clazz: int,
-        clazz_name: Optional[str],
-        bbox: List[int],
-        conf: float,
-    ):
-        self.clazz = clazz
-        self.clazz_name = clazz_name
-        self.bbox = bbox
-        self.conf = conf
-
-    @property
-    def ltrb(self):
-        l, t, r, b = self.bbox
-        return [int(x) for x in [l, t, r, b]]
-
-    @property
-    def area(self):
-        l, t, r, b = self.ltrb
-        return (r - l + 1) * (b - t + 1)
+        return boxes
 
 
-    def iou(self, other):
-        boxA = self.ltrb
-        boxB = other.ltrb
-        boxA = [int(x) for x in boxA]
-        boxB = [int(x) for x in boxB]
-
-        xA = max(boxA[0], boxB[0])
-        yA = max(boxA[1], boxB[1])
-        xB = min(boxA[2], boxB[2])
-        yB = min(boxA[3], boxB[3])
-
-        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
-
-        boxAArea = self.area
-        boxBArea = other.area
-
-        iou = interArea / float(boxAArea + boxBArea - interArea)
-
-        return iou
+    elif model_name == "ocr-layout-paddle":
+        from core.detectors.paddle_yolo import PaddleYoloDetector
 
 
-    def to_dict(self) -> Dict:
-        return {
-            "class": self.clazz,
-            "class_name": self.clazz_name,
-            "bbox": self.bbox,
-            "confidence": self.conf,
-        }
+        boxes = PaddleYoloDetector.predict(img, **kwargs)
 
 
-    def to_service_dict(self) -> Dict:
-        """
-        返回中间服务所需的格式
-        """
-        return {
-            "class": self.clazz,
-            "class_name": self.clazz_name,
-            "bbox": self.ltrb,
-            "confidence": self.conf,
-        }
+        # 该模型预测的类型暂时需要重新映射
+        # FIXME: 统一类型
+        _clazz_remap = {11: 10, 12: 11}
+        for b in boxes:
+            if b.clazz in _clazz_remap:
+                b.clazz = _clazz_remap[b.clazz]
 
 
-    @classmethod
-    def from_dict(cls, d) -> LayoutBox:
-        return cls(
-            clazz=d["class"],
-            clazz_name=d["class_name"],
-            bbox=d["bbox"],
-            conf=d["confidence"],
+        img_h, img_w = img.shape[:2]
+        boxes = clip_boxes_to_image_bound(boxes, img_w, img_h)
+        boxes = filter_boxes_by_conf(boxes, conf_threshold=0.1)
+        boxes = filter_boxes_by_overlaps(
+            boxes, overlaps_iou_threshold=0.8, overlaps_max_count=3
         )
         )
-
-    def __repr__(self):
-        return f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, bbox={self.bbox}, conf={self.conf})"
+        return boxes
+    else:
+        raise RuntimeError(f"Invalid model name: {model_name}")

+ 8 - 24
server.py

@@ -7,7 +7,8 @@ import cv2
 import numpy as np
 import numpy as np
 
 
 import base64
 import base64
-from core.predictor import LayoutBox, predict_img
+from core.predictor import predict_img
+from core.layout import LayoutBox
 from sx_utils import format_print
 from sx_utils import format_print
 
 
 app = FastAPI()
 app = FastAPI()
@@ -31,7 +32,6 @@ clazz_names = [
     "style",
     "style",
     "table",
     "table",
     "text",
     "text",
-    "text_jhl",
     "title",
     "title",
 ]
 ]
 
 
@@ -69,7 +69,9 @@ def drag_and_drop_detect(request: Request):
 def detect_via_web_form(request: Request,
 def detect_via_web_form(request: Request,
                         file_list: List[UploadFile] = File(...),
                         file_list: List[UploadFile] = File(...),
                         model_name: str = Form(...),
                         model_name: str = Form(...),
-                        img_size: int = Form(1824)):
+                        img_size: int = Form(1824),
+                        multi_scale: bool = Form(False),
+                        ):
     '''
     '''
     Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
     Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
     Intended for human (non-api) users.
     Intended for human (non-api) users.
@@ -83,7 +85,7 @@ def detect_via_web_form(request: Request,
     # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
 
 
-    results = [predict_img(img, model_name, img_size) for img in img_batch_rgb]
+    results = [predict_img(img, model_name, img_size, multi_scale) for img in img_batch_rgb]
     
     
     json_results = boxes_list_to_json(results, clazz_names)
     json_results = boxes_list_to_json(results, clazz_names)
 
 
@@ -113,6 +115,7 @@ def detect_via_api(request: Request,
                    file_list: List[UploadFile] = File(...),
                    file_list: List[UploadFile] = File(...),
                    model_name: str = Form(...),
                    model_name: str = Form(...),
                    img_size: int = Form(1920),
                    img_size: int = Form(1920),
+                   multi_scale: bool = Form(False),
                    download_image: Optional[bool] = Form(False)):
                    download_image: Optional[bool] = Form(False)):
     '''
     '''
     Requires an image file upload, model name (ex. yolov5s).
     Requires an image file upload, model name (ex. yolov5s).
@@ -133,7 +136,7 @@ def detect_via_api(request: Request,
     # 转换图片格式
     # 转换图片格式
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
     img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
     # 选用相关模型进行模版识别
     # 选用相关模型进行模版识别
-    results = [predict_img(img, model_name, img_size) for img in img_batch_rgb]
+    results = [predict_img(img, model_name, img_size, multi_scale) for img in img_batch_rgb]
     # 处理结果数据
     # 处理结果数据
     json_results = boxes_list_to_json(results, clazz_names)
     json_results = boxes_list_to_json(results, clazz_names)
 
 
@@ -212,22 +215,3 @@ def boxes_list_to_json(boxes_list: List[List[LayoutBox]], clazz_names: List[str]
 def ping():
 def ping():
     print("->ping")
     print("->ping")
     return "pong!"
     return "pong!"
-
-
-# if __name__ == '__main__':
-#     import uvicorn
-#     import argparse
-#
-#     parser = argparse.ArgumentParser()
-#     parser.add_argument('--host', default='localhost')
-#     parser.add_argument('--port', default=8080)
-#     parser.add_argument('--precache-models', action='store_true',
-#                         help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
-#     opt = parser.parse_args()
-#
-#     # if opt.precache_models:
-#     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
-#     #                     for model_name in model_selection_options}
-#
-#     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
-#     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 9 - 0
templates/drag_and_drop_detect.html

@@ -43,6 +43,14 @@ TODO:
         name="img_size"
         name="img_size"
         value="1824"
         value="1824"
       />
       />
+      <input
+        type="checkbox"
+        id="multi_scale"
+        name="multi_scale"
+      />
+      <label for="multi_scale" class="form-label">
+        <b>Multi-scale Inference</b>
+      </label>
     </div>
     </div>
 
 
     <div class="col">
     <div class="col">
@@ -329,6 +337,7 @@ TODO:
     formData.append('file_list', img)
     formData.append('file_list', img)
     formData.append('model_name', $('#model_name').val())
     formData.append('model_name', $('#model_name').val())
     formData.append('img_size', $('#img_size').val())
     formData.append('img_size', $('#img_size').val())
+    formData.append('multi_scale', $('#multi_scale').val())
 
 
     $.ajax({
     $.ajax({
       url: '/detect',
       url: '/detect',

+ 10 - 0
templates/home.html

@@ -43,6 +43,16 @@ Implements a simple Bootstrap 5 form submission interface for YOLOv5 detection.
         value="1824"
         value="1824"
       />
       />
     </div>
     </div>
+    <div class="my-2">
+      <input
+        type="checkbox"
+        id="multi_scale"
+        name="multi_scale"
+      />
+      <label for="multi_scale" class="form-label">
+        Multi-scale Inference
+      </label>
+    </div>
 
 
     <button class="btn btn-primary" type="submit">Submit</button>
     <button class="btn btn-primary" type="submit">Submit</button>
   </form>
   </form>