post_process.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import itertools
  2. from typing import List
  3. from core.layout import LayoutBox
  4. def merge_boxes_list(
  5. boxes_list: List[List[LayoutBox]],
  6. img_w: int,
  7. img_h: int,
  8. method: str = "nms",
  9. iou_threshold=0.5,
  10. ) -> List[LayoutBox]:
  11. """合并多组检测框列表,调用 Weighted-Boxes-Fusion 库实现。
  12. 可用于合并多个模型的预测结果,或合并单个模型多次的预测结果。
  13. See: https://github.com/ZFTurbo/Weighted-Boxes-Fusion
  14. method
  15. Args:
  16. boxes_list (List[List[LayoutBox]]):
  17. 多组检测框列表
  18. img_w (int):
  19. 图像宽度
  20. img_h (int):
  21. 图像高度
  22. method (str, optional):
  23. 合并方法名,可选值: ["nms", "soft_nms", "nmw", "wbf"]. Defaults to "nms".
  24. iou_threshold (float, optional):
  25. bbox 匹配的 IoU 阈值. Defaults to 0.5.
  26. Returns:
  27. List[LayoutBox]: 合并后的检测框列表
  28. """
  29. def ltrb_to_nltrb(ltrb, img_w, img_h):
  30. """
  31. Normalize ltrb.
  32. """
  33. l, t, r, b = ltrb
  34. nl = l / img_w
  35. nt = t / img_h
  36. nr = r / img_w
  37. nb = b / img_h
  38. print(ltrb)
  39. print(img_h, img_w)
  40. return [nl, nt, nr, nb]
  41. def nltrb_to_ltrb(nltrb, img_w, img_h):
  42. """
  43. Denormalize normalized ltrb.
  44. """
  45. nl, nt, nr, nb = nltrb
  46. l = nl * img_w
  47. t = nt * img_h
  48. r = nr * img_w
  49. b = nb * img_h
  50. return [l, t, r, b]
  51. from ensemble_boxes import (
  52. nms,
  53. soft_nms,
  54. non_maximum_weighted,
  55. weighted_boxes_fusion,
  56. )
  57. merge_funcs = {
  58. "nms": nms,
  59. "soft_nms": soft_nms,
  60. "nmw": non_maximum_weighted,
  61. "wbf": weighted_boxes_fusion,
  62. }
  63. assert method in merge_funcs.keys()
  64. merge_func = merge_funcs[method]
  65. nltrbs_list = [
  66. [ltrb_to_nltrb(b.ltrb, img_w, img_h) for b in boxes]
  67. for boxes in boxes_list
  68. ]
  69. scores_list = [[b.conf for b in boxes] for boxes in boxes_list]
  70. labels_list = [[b.clazz for b in boxes] for boxes in boxes_list]
  71. nltrbs, scores, labels = merge_func(
  72. nltrbs_list, scores_list, labels_list, iou_thr=iou_threshold
  73. )
  74. merged_boxes = [
  75. LayoutBox(
  76. clazz=int(label),
  77. bbox=nltrb_to_ltrb(nltrb, img_w, img_h),
  78. conf=float(score),
  79. )
  80. for nltrb, score, label in zip(nltrbs, scores, labels)
  81. ]
  82. return merged_boxes
  83. def clip_boxes_to_image_bound(
  84. boxes: List[LayoutBox], img_w: int, img_h: int
  85. ) -> List[LayoutBox]:
  86. """
  87. 裁剪检测框尺寸以防止超出图像边界。
  88. """
  89. def clip_bbox(bbox: List[int], img_w: int, img_h: int) -> List[int]:
  90. l, t, r, b = bbox
  91. l = max(0, int(l))
  92. t = max(0, int(t))
  93. r = min(img_w, int(r))
  94. b = min(img_h, int(b))
  95. return [l, t, r, b]
  96. for box in boxes:
  97. box.bbox = clip_bbox(box.bbox, img_w, img_h)
  98. return boxes
  99. def filter_boxes_by_conf(
  100. boxes: List[LayoutBox],
  101. conf_threshold: float,
  102. ) -> List[LayoutBox]:
  103. """
  104. 按置信度过滤检测框。
  105. """
  106. boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
  107. return boxes
  108. def filter_boxes_by_overlaps(
  109. boxes: List[LayoutBox],
  110. overlaps_iou_threshold: float,
  111. overlaps_max_count: int,
  112. ) -> List[LayoutBox]:
  113. """
  114. 按置信度和 IoU 过滤检测框。
  115. 对多个 IoU 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
  116. """
  117. # 按置信度进行排序
  118. boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
  119. # 每一个桶中都是重叠区域较大的LayoutBox
  120. buckets: List[List[LayoutBox]] = []
  121. # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
  122. def get_bucket(box: LayoutBox):
  123. for bucket in buckets:
  124. for e in bucket:
  125. if box.iou(e) >= overlaps_iou_threshold:
  126. return bucket
  127. return None
  128. for box in boxes:
  129. bucket = get_bucket(box)
  130. # 若当前不存在于目标layout重叠的内容,则新建一个桶
  131. if not bucket:
  132. buckets.append([box])
  133. # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
  134. elif len(bucket) < overlaps_max_count:
  135. bucket.append(box)
  136. # 将所用桶中的数据合为一个列表
  137. new_boxes = list(itertools.chain.from_iterable(buckets))
  138. return new_boxes