import itertools
from typing import List

from core.layout import LayoutBox


def merge_boxes_list(
    boxes_list: List[List[LayoutBox]],
    img_w: int,
    img_h: int,
    method: str = "nms",
    iou_threshold: float = 0.5,
) -> List[LayoutBox]:
    """Merge several lists of detection boxes via the Weighted-Boxes-Fusion library.

    Can be used to ensemble the predictions of several models, or several
    predictions made by a single model.

    See: https://github.com/ZFTurbo/Weighted-Boxes-Fusion

    Args:
        boxes_list (List[List[LayoutBox]]): Groups of detection boxes, one
            list per prediction source.
        img_w (int): Image width, used to normalize/denormalize coordinates.
        img_h (int): Image height, used to normalize/denormalize coordinates.
        method (str, optional): Merge method name, one of
            ["nms", "soft_nms", "nmw", "wbf"]. Defaults to "nms".
        iou_threshold (float, optional): IoU threshold used to match boxes.
            Defaults to 0.5.

    Returns:
        List[LayoutBox]: The merged detection boxes, with coordinates mapped
        back to pixel space. Empty when ``boxes_list`` contains no boxes.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """

    def ltrb_to_nltrb(ltrb, img_w, img_h):
        """Normalize an (l, t, r, b) box by the image size."""
        l, t, r, b = ltrb
        return [l / img_w, t / img_h, r / img_w, b / img_h]

    def nltrb_to_ltrb(nltrb, img_w, img_h):
        """Map a normalized (l, t, r, b) box back to pixel coordinates."""
        nl, nt, nr, nb = nltrb
        return [nl * img_w, nt * img_h, nr * img_w, nb * img_h]

    # Imported lazily so the rest of this module is usable without the
    # optional ensemble_boxes dependency.
    from ensemble_boxes import (
        nms,
        soft_nms,
        non_maximum_weighted,
        weighted_boxes_fusion,
    )

    merge_funcs = {
        "nms": nms,
        "soft_nms": soft_nms,
        "nmw": non_maximum_weighted,
        "wbf": weighted_boxes_fusion,
    }
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if method not in merge_funcs:
        raise ValueError(
            f"Unsupported merge method {method!r}; "
            f"expected one of {sorted(merge_funcs)}"
        )
    merge_func = merge_funcs[method]

    # Nothing to merge: skip the library call entirely.
    if not any(boxes_list):
        return []

    nltrbs_list = [
        [ltrb_to_nltrb(b.ltrb, img_w, img_h) for b in boxes] for boxes in boxes_list
    ]
    scores_list = [[b.conf for b in boxes] for boxes in boxes_list]
    labels_list = [[b.clazz for b in boxes] for boxes in boxes_list]

    nltrbs, scores, labels = merge_func(
        nltrbs_list, scores_list, labels_list, iou_thr=iou_threshold
    )

    # The merge result comes back as numeric arrays; convert to plain
    # int/float so LayoutBox holds Python scalars.
    return [
        LayoutBox(
            clazz=int(label),
            bbox=nltrb_to_ltrb(nltrb, img_w, img_h),
            conf=float(score),
        )
        for nltrb, score, label in zip(nltrbs, scores, labels)
    ]


def clip_boxes_to_image_bound(
    boxes: List[LayoutBox], img_w: int, img_h: int
) -> List[LayoutBox]:
    """Clip box coordinates so they do not extend beyond the image bounds.

    Coordinates are truncated with ``int()``; boxes are modified in place and
    the same list is returned.
    """

    def clip_bbox(bbox: List[int], img_w: int, img_h: int) -> List[int]:
        l, t, r, b = bbox
        # Only the lower bound of l/t and the upper bound of r/b are clamped,
        # so a box lying entirely outside the image can end up with r < l
        # or b < t.
        l = max(0, int(l))
        t = max(0, int(t))
        r = min(img_w, int(r))
        b = min(img_h, int(b))
        return [l, t, r, b]

    for box in boxes:
        box.bbox = clip_bbox(box.bbox, img_w, img_h)
    return boxes


def filter_boxes_by_conf(
    boxes: List[LayoutBox],
    conf_threshold: float,
) -> List[LayoutBox]:
    """Keep only boxes whose confidence is at least ``conf_threshold``."""
    return [box for box in boxes if box.conf >= conf_threshold]


def filter_boxes_by_overlaps(
    boxes: List[LayoutBox],
    overlaps_iou_threshold: float,
    overlaps_max_count: int,
) -> List[LayoutBox]:
    """Filter boxes by confidence and mutual IoU.

    Boxes are visited in descending confidence order. Each box joins the first
    group containing a member whose IoU with it is >= ``overlaps_iou_threshold``;
    otherwise it starts a new group. A group accepts at most
    ``overlaps_max_count`` boxes, so only the highest-confidence boxes of each
    group of overlapping regions are kept.

    Note: the first box of every group is always kept, even when
    ``overlaps_max_count`` is 0 or negative.
    """
    # Highest confidence first, so each group keeps its best boxes.
    boxes = sorted(boxes, key=lambda e: e.conf, reverse=True)

    # Each bucket holds boxes that overlap heavily with one another.
    buckets: List[List[LayoutBox]] = []

    def get_bucket(box: LayoutBox):
        """Return the first bucket containing a box overlapping `box`, else None."""
        for bucket in buckets:
            if any(box.iou(e) >= overlaps_iou_threshold for e in bucket):
                return bucket
        return None

    for box in boxes:
        bucket = get_bucket(box)
        if bucket is None:
            # No existing group overlaps this box: start a new one.
            buckets.append([box])
        elif len(bucket) < overlaps_max_count:
            # Group found and not yet full: keep this (lower-confidence) box.
            bucket.append(box)
        # Otherwise the group is full and the box is dropped.

    # Flatten all buckets back into a single list.
    return list(itertools.chain.from_iterable(buckets))