123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- from __future__ import annotations
- from typing import List
- import numpy as np
- from core.layout import LayoutBox
- from core.post_process import (
- clip_boxes_to_image_bound,
- filter_boxes_by_conf,
- filter_boxes_by_overlaps,
- merge_boxes_list,
- )
def predict_img(
    img: np.ndarray,
    model_name: str,
    img_size: int,
    multi_scale: bool,
    **kwargs,
) -> List[LayoutBox]:
    """Predict layout detection boxes with the detector selected by name.

    Each detector module is imported inside the branch that uses it, so a
    model is only loaded when it is actually requested (per the original
    author's note, to save GPU resources).

    Args:
        img: Input image; its first two ``shape`` entries are read as
            (height, width).
        model_name: ``"ocr-layout"`` or ``"ocr-layout-paddle"``.
        img_size: Base inference size handed to the ``"ocr-layout"``
            detector. Not used by ``"ocr-layout-paddle"``.
        multi_scale: For ``"ocr-layout"`` only: run the detector at several
            sizes derived from ``img_size`` and merge the results. Not used
            by ``"ocr-layout-paddle"``.
        **kwargs: Extra options forwarded unchanged to the detector's
            ``predict`` call.

    Returns:
        The detected layout boxes.

    Raises:
        RuntimeError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "ocr-layout":
        from core.detectors.yolov7 import Yolov7Detector

        if not multi_scale:
            return Yolov7Detector.predict(img, img_size, **kwargs)

        # Run once per scaled input size, then merge all runs together.
        scale_factors = (0.5, 0.75, 1.0, 1.25, 1.5)
        per_scale_boxes = [
            # Forward **kwargs here too: the multi-scale path used to drop
            # them, so caller-supplied options were silently ignored.
            Yolov7Detector.predict(img, int(img_size * factor), **kwargs)
            for factor in scale_factors
        ]
        img_h, img_w = img.shape[:2]
        return merge_boxes_list(
            per_scale_boxes, img_w, img_h, method="nms", iou_threshold=0.55
        )

    if model_name == "ocr-layout-paddle":
        from core.detectors.paddle_yolo import PaddleYoloDetector

        boxes = PaddleYoloDetector.predict(img, **kwargs)

        # The class ids predicted by this model currently need remapping.
        # FIXME: unify the class ids across models.
        clazz_remap = {11: 10, 12: 11}
        for box in boxes:
            if box.clazz in clazz_remap:
                box.clazz = clazz_remap[box.clazz]

        img_h, img_w = img.shape[:2]
        boxes = clip_boxes_to_image_bound(boxes, img_w, img_h)
        boxes = filter_boxes_by_conf(boxes, conf_threshold=0.1)
        return filter_boxes_by_overlaps(
            boxes, overlaps_iou_threshold=0.8, overlaps_max_count=3
        )

    raise RuntimeError(f"Invalid model name: {model_name}")
|