post_process.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import itertools
  2. from typing import List
  3. from core.layout import LayoutBox
  4. def merge_boxes_list(
  5. boxes_list: List[List[LayoutBox]],
  6. img_w: int,
  7. img_h: int,
  8. method: str = "nms",
  9. iou_threshold=0.5,
  10. ) -> List[LayoutBox]:
  11. """合并多组检测框列表,调用 Weighted-Boxes-Fusion 库实现。
  12. 可用于合并多个模型的预测结果,或合并单个模型多次的预测结果。
  13. See: https://github.com/ZFTurbo/Weighted-Boxes-Fusion
  14. method
  15. Args:
  16. boxes_list (List[List[LayoutBox]]):
  17. 多组检测框列表
  18. img_w (int):
  19. 图像宽度
  20. img_h (int):
  21. 图像高度
  22. method (str, optional):
  23. 合并方法名,可选值: ["nms", "soft_nms", "nmw", "wbf"]. Defaults to "nms".
  24. iou_threshold (float, optional):
  25. bbox 匹配的 IoU 阈值. Defaults to 0.5.
  26. Returns:
  27. List[LayoutBox]: 合并后的检测框列表
  28. """
  29. def ltrb_to_nltrb(ltrb, img_w, img_h):
  30. """
  31. Normalize ltrb.
  32. """
  33. l, t, r, b = ltrb
  34. nl = l / img_w
  35. nt = t / img_h
  36. nr = r / img_w
  37. nb = b / img_h
  38. return [nl, nt, nr, nb]
  39. def nltrb_to_ltrb(nltrb, img_w, img_h):
  40. """
  41. Denormalize normalized ltrb.
  42. """
  43. nl, nt, nr, nb = nltrb
  44. l = nl * img_w
  45. t = nt * img_h
  46. r = nr * img_w
  47. b = nb * img_h
  48. return [l, t, r, b]
  49. from ensemble_boxes import (
  50. nms,
  51. soft_nms,
  52. non_maximum_weighted,
  53. weighted_boxes_fusion,
  54. )
  55. merge_funcs = {
  56. "nms": nms,
  57. "soft_nms": soft_nms,
  58. "nmw": non_maximum_weighted,
  59. "wbf": weighted_boxes_fusion,
  60. }
  61. assert method in merge_funcs.keys()
  62. merge_func = merge_funcs[method]
  63. nltrbs_list = [
  64. [ltrb_to_nltrb(b.ltrb, img_w, img_h) for b in boxes]
  65. for boxes in boxes_list
  66. ]
  67. scores_list = [[b.conf for b in boxes] for boxes in boxes_list]
  68. labels_list = [[b.clazz for b in boxes] for boxes in boxes_list]
  69. nltrbs, scores, labels = merge_func(
  70. nltrbs_list, scores_list, labels_list, iou_thr=iou_threshold
  71. )
  72. merged_boxes = [
  73. LayoutBox(
  74. clazz=int(label),
  75. bbox=nltrb_to_ltrb(nltrb, img_w, img_h),
  76. conf=float(score),
  77. )
  78. for nltrb, score, label in zip(nltrbs, scores, labels)
  79. ]
  80. return merged_boxes
  81. def clip_boxes_to_image_bound(
  82. boxes: List[LayoutBox], img_w: int, img_h: int
  83. ) -> List[LayoutBox]:
  84. """
  85. 裁剪检测框尺寸以防止超出图像边界。
  86. """
  87. def clip_bbox(bbox: List[int], img_w: int, img_h: int) -> List[int]:
  88. l, t, r, b = bbox
  89. l = max(0, int(l))
  90. t = max(0, int(t))
  91. r = min(img_w, int(r))
  92. b = min(img_h, int(b))
  93. return [l, t, r, b]
  94. for box in boxes:
  95. box.bbox = clip_bbox(box.bbox, img_w, img_h)
  96. return boxes
  97. def filter_boxes_by_conf(
  98. boxes: List[LayoutBox],
  99. conf_threshold: float,
  100. ) -> List[LayoutBox]:
  101. """
  102. 按置信度过滤检测框。
  103. """
  104. boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
  105. return boxes
  106. def filter_boxes_by_overlaps(
  107. boxes: List[LayoutBox],
  108. overlaps_iou_threshold: float,
  109. overlaps_max_count: int,
  110. ) -> List[LayoutBox]:
  111. """
  112. 按置信度和 IoU 过滤检测框。
  113. 对多个 IoU 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
  114. """
  115. # 按置信度进行排序
  116. boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
  117. # 每一个桶中都是重叠区域较大的LayoutBox
  118. buckets: List[List[LayoutBox]] = []
  119. # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
  120. def get_bucket(box: LayoutBox):
  121. for bucket in buckets:
  122. for e in bucket:
  123. if box.iou(e) >= overlaps_iou_threshold:
  124. return bucket
  125. return None
  126. for box in boxes:
  127. bucket = get_bucket(box)
  128. # 若当前不存在于目标layout重叠的内容,则新建一个桶
  129. if not bucket:
  130. buckets.append([box])
  131. # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
  132. elif len(bucket) < overlaps_max_count:
  133. bucket.append(box)
  134. # 将所用桶中的数据合为一个列表
  135. new_boxes = list(itertools.chain.from_iterable(buckets))
  136. return new_boxes