from __future__ import annotations from typing import List import numpy as np from core.layout import LayoutBox from core.post_process import ( clip_boxes_to_image_bound, filter_boxes_by_conf, filter_boxes_by_overlaps, merge_boxes_list, ) def predict_img( img: np.ndarray, model_name: str, img_size: int, multi_scale: bool, **kwargs, ) -> List[LayoutBox]: """ 根据模型名称预测布局检测框。模型仅在被调用时加载(以节省GPU资源)。 """ if model_name == "ocr-layout": from core.detectors.yolov7 import Yolov7Detector if not multi_scale: return Yolov7Detector.predict(img, img_size, **kwargs) scale_factors = [0.5, 0.75, 1.0, 1.25, 1.5] img_sizes = [int(img_size * factor) for factor in scale_factors] boxes_list = [Yolov7Detector.predict(img, size) for size in img_sizes] img_h, img_w = img.shape[:2] boxes = merge_boxes_list( boxes_list, img_w, img_h, method="nms", iou_threshold=0.55 ) return boxes elif model_name == "ocr-layout-paddle": from core.detectors.paddle_yolo import PaddleYoloDetector boxes = PaddleYoloDetector.predict(img, **kwargs) # 该模型预测的类型暂时需要重新映射 # FIXME: 统一类型 _clazz_remap = {11: 10, 12: 11} for b in boxes: if b.clazz in _clazz_remap: b.clazz = _clazz_remap[b.clazz] img_h, img_w = img.shape[:2] boxes = clip_boxes_to_image_bound(boxes, img_w, img_h) boxes = filter_boxes_by_conf(boxes, conf_threshold=0.1) boxes = filter_boxes_by_overlaps( boxes, overlaps_iou_threshold=0.8, overlaps_max_count=3 ) return boxes else: raise RuntimeError(f"Invalid model name: {model_name}")