123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- import itertools
- from typing import List
- from core.layout import LayoutBox
- def merge_boxes_list(
- boxes_list: List[List[LayoutBox]],
- img_w: int,
- img_h: int,
- method: str = "nms",
- iou_threshold=0.5,
- ) -> List[LayoutBox]:
- """合并多组检测框列表,调用 Weighted-Boxes-Fusion 库实现。
- 可用于合并多个模型的预测结果,或合并单个模型多次的预测结果。
- See: https://github.com/ZFTurbo/Weighted-Boxes-Fusion
- method
- Args:
- boxes_list (List[List[LayoutBox]]):
- 多组检测框列表
- img_w (int):
- 图像宽度
- img_h (int):
- 图像高度
- method (str, optional):
- 合并方法名,可选值: ["nms", "soft_nms", "nmw", "wbf"]. Defaults to "nms".
- iou_threshold (float, optional):
- bbox 匹配的 IoU 阈值. Defaults to 0.5.
- Returns:
- List[LayoutBox]: 合并后的检测框列表
- """
- def ltrb_to_nltrb(ltrb, img_w, img_h):
- """
- Normalize ltrb.
- """
- l, t, r, b = ltrb
- nl = l / img_w
- nt = t / img_h
- nr = r / img_w
- nb = b / img_h
- print(ltrb)
- print(img_h, img_w)
- return [nl, nt, nr, nb]
- def nltrb_to_ltrb(nltrb, img_w, img_h):
- """
- Denormalize normalized ltrb.
- """
- nl, nt, nr, nb = nltrb
- l = nl * img_w
- t = nt * img_h
- r = nr * img_w
- b = nb * img_h
- return [l, t, r, b]
- from ensemble_boxes import (
- nms,
- soft_nms,
- non_maximum_weighted,
- weighted_boxes_fusion,
- )
- merge_funcs = {
- "nms": nms,
- "soft_nms": soft_nms,
- "nmw": non_maximum_weighted,
- "wbf": weighted_boxes_fusion,
- }
- assert method in merge_funcs.keys()
- merge_func = merge_funcs[method]
- nltrbs_list = [
- [ltrb_to_nltrb(b.ltrb, img_w, img_h) for b in boxes]
- for boxes in boxes_list
- ]
- scores_list = [[b.conf for b in boxes] for boxes in boxes_list]
- labels_list = [[b.clazz for b in boxes] for boxes in boxes_list]
- nltrbs, scores, labels = merge_func(
- nltrbs_list, scores_list, labels_list, iou_thr=iou_threshold
- )
- merged_boxes = [
- LayoutBox(
- clazz=int(label),
- bbox=nltrb_to_ltrb(nltrb, img_w, img_h),
- conf=float(score),
- )
- for nltrb, score, label in zip(nltrbs, scores, labels)
- ]
- return merged_boxes
- def clip_boxes_to_image_bound(
- boxes: List[LayoutBox], img_w: int, img_h: int
- ) -> List[LayoutBox]:
- """
- 裁剪检测框尺寸以防止超出图像边界。
- """
- def clip_bbox(bbox: List[int], img_w: int, img_h: int) -> List[int]:
- l, t, r, b = bbox
- l = max(0, int(l))
- t = max(0, int(t))
- r = min(img_w, int(r))
- b = min(img_h, int(b))
- return [l, t, r, b]
- for box in boxes:
- box.bbox = clip_bbox(box.bbox, img_w, img_h)
- return boxes
- def filter_boxes_by_conf(
- boxes: List[LayoutBox],
- conf_threshold: float,
- ) -> List[LayoutBox]:
- """
- 按置信度过滤检测框。
- """
- boxes = list(filter(lambda e: e.conf >= conf_threshold, boxes))
- return boxes
- def filter_boxes_by_overlaps(
- boxes: List[LayoutBox],
- overlaps_iou_threshold: float,
- overlaps_max_count: int,
- ) -> List[LayoutBox]:
- """
- 按置信度和 IoU 过滤检测框。
- 对多个 IoU 大于 `overlaps_iou_threshold` 的区域,仅保留 `overlaps_max_count` 个置信度最高的。
- """
- # 按置信度进行排序
- boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)
- # 每一个桶中都是重叠区域较大的LayoutBox
- buckets: List[List[LayoutBox]] = []
- # 将目标于每一个桶中的每一个LayoutBox进行比较,找到目标应该存在于哪一个桶
- def get_bucket(box: LayoutBox):
- for bucket in buckets:
- for e in bucket:
- if box.iou(e) >= overlaps_iou_threshold:
- return bucket
- return None
- for box in boxes:
- bucket = get_bucket(box)
- # 若当前不存在于目标layout重叠的内容,则新建一个桶
- if not bucket:
- buckets.append([box])
- # 若找到目标应该位于的桶,则只收取置信度较高的overlaps_max_count个框选区域
- elif len(bucket) < overlaps_max_count:
- bucket.append(box)
- # 将所用桶中的数据合为一个列表
- new_boxes = list(itertools.chain.from_iterable(buckets))
- return new_boxes
|