|
@@ -1,11 +1,23 @@
|
|
|
from __future__ import annotations
|
|
|
-from typing import Dict, List, Optional
|
|
|
+from typing import List
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
+from core.layout import LayoutBox
|
|
|
+from core.post_process import (
|
|
|
+ clip_boxes_to_image_bound,
|
|
|
+ filter_boxes_by_conf,
|
|
|
+ filter_boxes_by_overlaps,
|
|
|
+ merge_boxes_list,
|
|
|
+)
|
|
|
+
|
|
|
|
|
|
def predict_img(
|
|
|
- img: np.ndarray, model_name: str, img_size: int, **kwargs
|
|
|
+ img: np.ndarray,
|
|
|
+ model_name: str,
|
|
|
+ img_size: int,
|
|
|
+ multi_scale: bool,
|
|
|
+ **kwargs,
|
|
|
) -> List[LayoutBox]:
|
|
|
"""
|
|
|
根据模型名称预测布局检测框。模型仅在被调用时加载(以节省GPU资源)。
|
|
@@ -13,92 +25,36 @@ def predict_img(
|
|
|
if model_name == "ocr-layout":
|
|
|
from core.detectors.yolov7 import Yolov7Detector
|
|
|
|
|
|
- return Yolov7Detector.predict(img, img_size, **kwargs)
|
|
|
+ if not multi_scale:
|
|
|
+ return Yolov7Detector.predict(img, img_size, **kwargs)
|
|
|
|
|
|
- elif model_name == "ocr-layout-paddle":
|
|
|
- from core.detectors.paddle_yolo import PaddleYoloDetector
|
|
|
-
|
|
|
- return PaddleYoloDetector.predict(
|
|
|
- img,
|
|
|
- conf_threshold=0.3,
|
|
|
- overlaps_iou_threshold=0.85,
|
|
|
- overlaps_max_count=3,
|
|
|
- **kwargs,
|
|
|
+ scale_factors = [0.5, 0.75, 1.0, 1.25, 1.5]
|
|
|
+ img_sizes = [int(img_size * factor) for factor in scale_factors]
|
|
|
+ boxes_list = [Yolov7Detector.predict(img, size) for size in img_sizes]
|
|
|
+ img_h, img_w = img.shape[:2]
|
|
|
+ boxes = merge_boxes_list(
|
|
|
+ boxes_list, img_w, img_h, method="nms", iou_threshold=0.55
|
|
|
)
|
|
|
- else:
|
|
|
- raise RuntimeError(f"Invalid model name: {model_name}")
|
|
|
-
|
|
|
-
|
|
|
-class LayoutBox:
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- clazz: int,
|
|
|
- clazz_name: Optional[str],
|
|
|
- bbox: List[int],
|
|
|
- conf: float,
|
|
|
- ):
|
|
|
- self.clazz = clazz
|
|
|
- self.clazz_name = clazz_name
|
|
|
- self.bbox = bbox
|
|
|
- self.conf = conf
|
|
|
-
|
|
|
- @property
|
|
|
- def ltrb(self):
|
|
|
- l, t, r, b = self.bbox
|
|
|
- return [int(x) for x in [l, t, r, b]]
|
|
|
-
|
|
|
- @property
|
|
|
- def area(self):
|
|
|
- l, t, r, b = self.ltrb
|
|
|
- return (r - l + 1) * (b - t + 1)
|
|
|
+ return boxes
|
|
|
|
|
|
- def iou(self, other):
|
|
|
- boxA = self.ltrb
|
|
|
- boxB = other.ltrb
|
|
|
- boxA = [int(x) for x in boxA]
|
|
|
- boxB = [int(x) for x in boxB]
|
|
|
-
|
|
|
- xA = max(boxA[0], boxB[0])
|
|
|
- yA = max(boxA[1], boxB[1])
|
|
|
- xB = min(boxA[2], boxB[2])
|
|
|
- yB = min(boxA[3], boxB[3])
|
|
|
-
|
|
|
- interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
|
|
-
|
|
|
- boxAArea = self.area
|
|
|
- boxBArea = other.area
|
|
|
-
|
|
|
- iou = interArea / float(boxAArea + boxBArea - interArea)
|
|
|
-
|
|
|
- return iou
|
|
|
+ elif model_name == "ocr-layout-paddle":
|
|
|
+ from core.detectors.paddle_yolo import PaddleYoloDetector
|
|
|
|
|
|
- def to_dict(self) -> Dict:
|
|
|
- return {
|
|
|
- "class": self.clazz,
|
|
|
- "class_name": self.clazz_name,
|
|
|
- "bbox": self.bbox,
|
|
|
- "confidence": self.conf,
|
|
|
- }
|
|
|
+ boxes = PaddleYoloDetector.predict(img, **kwargs)
|
|
|
|
|
|
- def to_service_dict(self) -> Dict:
|
|
|
- """
|
|
|
- 返回中间服务所需的格式
|
|
|
- """
|
|
|
- return {
|
|
|
- "class": self.clazz,
|
|
|
- "class_name": self.clazz_name,
|
|
|
- "bbox": self.ltrb,
|
|
|
- "confidence": self.conf,
|
|
|
- }
|
|
|
+ # 该模型预测的类型暂时需要重新映射
|
|
|
+ # FIXME: 统一类型
|
|
|
+ _clazz_remap = {11: 10, 12: 11}
|
|
|
+ for b in boxes:
|
|
|
+ if b.clazz in _clazz_remap:
|
|
|
+ b.clazz = _clazz_remap[b.clazz]
|
|
|
|
|
|
- @classmethod
|
|
|
- def from_dict(cls, d) -> LayoutBox:
|
|
|
- return cls(
|
|
|
- clazz=d["class"],
|
|
|
- clazz_name=d["class_name"],
|
|
|
- bbox=d["bbox"],
|
|
|
- conf=d["confidence"],
|
|
|
+ img_h, img_w = img.shape[:2]
|
|
|
+ boxes = clip_boxes_to_image_bound(boxes, img_w, img_h)
|
|
|
+ boxes = filter_boxes_by_conf(boxes, conf_threshold=0.1)
|
|
|
+ boxes = filter_boxes_by_overlaps(
|
|
|
+ boxes, overlaps_iou_threshold=0.8, overlaps_max_count=3
|
|
|
)
|
|
|
-
|
|
|
- def __repr__(self):
|
|
|
- return f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, bbox={self.bbox}, conf={self.conf})"
|
|
|
+ return boxes
|
|
|
+ else:
|
|
|
+ raise RuntimeError(f"Invalid model name: {model_name}")
|