predictor.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from __future__ import annotations
  2. from typing import Dict, List, Optional
  3. import numpy as np
  def predict_img(
      img: np.ndarray, model_name: str, img_size: int, **kwargs
  ) -> List[LayoutBox]:
      """Predict layout boxes for ``img`` with the detector named ``model_name``.

      Each detector module is imported lazily inside its branch, so a model is
      only loaded when it is actually requested (to save GPU resources).

      Args:
          img: Input image array.
          model_name: ``"ocr-layout"`` selects ``Yolov7Detector``;
              ``"ocr-layout-paddle"`` selects ``PaddleYoloDetector``.
          img_size: Forwarded to ``Yolov7Detector.predict``; not used by the
              Paddle branch.
          **kwargs: Extra options forwarded to the selected detector. For the
              Paddle branch they may override the defaults
              ``conf_threshold=0.3``, ``overlaps_iou_threshold=0.85`` and
              ``overlaps_max_count=3``.

      Returns:
          Whatever the selected detector's ``predict`` returns (annotated as a
          list of ``LayoutBox``).

      Raises:
          RuntimeError: If ``model_name`` is not one of the names above.
      """
      if model_name == "ocr-layout":
          from core.detectors.yolov7 import Yolov7Detector

          return Yolov7Detector.predict(img, img_size, **kwargs)

      if model_name == "ocr-layout-paddle":
          from core.detectors.paddle_yolo import PaddleYoloDetector

          # Defaults first, caller kwargs last: callers can now tune these
          # thresholds instead of hitting "got multiple values for keyword
          # argument" when they pass one of them explicitly.
          paddle_kwargs = {
              "conf_threshold": 0.3,
              "overlaps_iou_threshold": 0.85,
              "overlaps_max_count": 3,
              **kwargs,
          }
          return PaddleYoloDetector.predict(img, **paddle_kwargs)

      raise RuntimeError(f"Invalid model name: {model_name}")
  class LayoutBox:
      """A single layout-detection result: class id/name, bounding box, confidence.

      ``bbox`` is stored exactly as given. ``ltrb`` exposes it as four ints in
      ``[left, top, right, bottom]`` order. ``area`` and ``iou`` treat the
      right/bottom coordinates as inclusive, i.e. each side length is
      ``end - start + 1``.
      """

      def __init__(
          self,
          clazz: int,
          clazz_name: Optional[str],
          bbox: List[int],
          conf: float,
      ) -> None:
          self.clazz = clazz  # numeric class id
          self.clazz_name = clazz_name  # class label; may be None
          self.bbox = bbox  # raw [l, t, r, b]; stored as-is, not copied/converted
          self.conf = conf  # detection confidence

      @property
      def ltrb(self) -> List[int]:
          """Return ``[left, top, right, bottom]`` with each value passed through ``int()``."""
          # Unpacking also rejects a bbox that does not have exactly 4 values.
          left, top, right, bottom = self.bbox
          return [int(left), int(top), int(right), int(bottom)]

      @property
      def area(self) -> int:
          """Return the inclusive pixel area of the box.

          Inverted or degenerate boxes (right < left or bottom < top) have area
          0. They never get the meaningless, possibly positive, product of two
          negative side lengths.
          """
          left, top, right, bottom = self.ltrb
          width = max(0, right - left + 1)
          height = max(0, bottom - top + 1)
          return width * height

      def iou(self, other: LayoutBox) -> float:
          """Return the intersection-over-union of this box and ``other``.

          Returns 0.0 when the union is empty (both boxes degenerate) instead
          of raising ZeroDivisionError.
          """
          # ltrb already yields ints, so no further conversion is needed.
          a_left, a_top, a_right, a_bottom = self.ltrb
          b_left, b_top, b_right, b_bottom = other.ltrb
          inter_w = max(0, min(a_right, b_right) - max(a_left, b_left) + 1)
          inter_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top) + 1)
          inter_area = inter_w * inter_h
          union_area = self.area + other.area - inter_area
          if union_area <= 0:
              return 0.0
          return inter_area / float(union_area)

      def to_dict(self) -> Dict:
          """Serialize to a dict, keeping ``bbox`` exactly as stored."""
          return {
              "class": self.clazz,
              "class_name": self.clazz_name,
              "bbox": self.bbox,
              "confidence": self.conf,
          }

      def to_service_dict(self) -> Dict:
          """Return the format required by the intermediate service.

          Same keys as ``to_dict``, but ``bbox`` is the int-converted ``ltrb``.
          """
          return {
              "class": self.clazz,
              "class_name": self.clazz_name,
              "bbox": self.ltrb,
              "confidence": self.conf,
          }

      @classmethod
      def from_dict(cls, d: Dict) -> LayoutBox:
          """Build a box from a dict with the keys written by ``to_dict``.

          Raises KeyError if any of those keys is missing.
          """
          return cls(
              clazz=d["class"],
              clazz_name=d["class_name"],
              bbox=d["bbox"],
              conf=d["confidence"],
          )

      def __repr__(self) -> str:
          return (
              f"LayoutBox(class={self.clazz}, class_name={self.clazz_name}, "
              f"bbox={self.bbox}, conf={self.conf})"
          )