predictor.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import annotations
  2. from typing import List
  3. import numpy as np
  4. from core.layout import LayoutBox
  5. from core.post_process import (
  6. clip_boxes_to_image_bound,
  7. filter_boxes_by_conf,
  8. filter_boxes_by_overlaps,
  9. merge_boxes_list,
  10. )
  11. def predict_img(
  12. img: np.ndarray,
  13. model_name: str,
  14. img_size: int,
  15. multi_scale: bool,
  16. **kwargs,
  17. ) -> List[LayoutBox]:
  18. """
  19. 根据模型名称预测布局检测框。模型仅在被调用时加载(以节省GPU资源)。
  20. """
  21. if model_name == "ocr-layout":
  22. from core.detectors.yolov7 import Yolov7Detector
  23. if not multi_scale:
  24. return Yolov7Detector.predict(img, img_size, **kwargs)
  25. scale_factors = [0.5, 0.75, 1.0, 1.25, 1.5]
  26. img_sizes = [int(img_size * factor) for factor in scale_factors]
  27. boxes_list = [Yolov7Detector.predict(img, size) for size in img_sizes]
  28. img_h, img_w = img.shape[:2]
  29. boxes = merge_boxes_list(
  30. boxes_list, img_w, img_h, method="nms", iou_threshold=0.55
  31. )
  32. return boxes
  33. elif model_name == "ocr-layout-paddle":
  34. from core.detectors.paddle_yolo import PaddleYoloDetector
  35. boxes = PaddleYoloDetector.predict(img, **kwargs)
  36. # 该模型预测的类型暂时需要重新映射
  37. # FIXME: 统一类型
  38. _clazz_remap = {11: 10, 12: 11}
  39. for b in boxes:
  40. if b.clazz in _clazz_remap:
  41. b.clazz = _clazz_remap[b.clazz]
  42. img_h, img_w = img.shape[:2]
  43. boxes = clip_boxes_to_image_bound(boxes, img_w, img_h)
  44. boxes = filter_boxes_by_conf(boxes, conf_threshold=0.1)
  45. boxes = filter_boxes_by_overlaps(
  46. boxes, overlaps_iou_threshold=0.8, overlaps_max_count=3
  47. )
  48. return boxes
  49. else:
  50. raise RuntimeError(f"Invalid model name: {model_name}")