# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch

from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils.metrics import box_iou

from .utils import adjust_bboxes_to_image_border


class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in the
    Ultralytics YOLO framework.

    This class extends SegmentationPredictor and reuses its post-processing unchanged, then adjusts the resulting
    boxes: any box that very nearly covers the whole original image is snapped to the exact full-image box.
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply SegmentationPredictor post-processing, then snap near-full-image boxes to the image bounds.

        Args:
            preds: Raw model predictions, forwarded unchanged to ``SegmentationPredictor.postprocess``.
            img: Model input batch, forwarded unchanged to ``SegmentationPredictor.postprocess``.
            orig_imgs: Original images, forwarded unchanged to ``SegmentationPredictor.postprocess``.

        Returns:
            The results produced by ``SegmentationPredictor.postprocess``. In each result, every box whose IoU
            with the full-image box exceeds 0.9 has its xyxy coordinates replaced by that full-image box.
        """
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            # orig_shape is indexed as (height, width) here, so the full-image box in xyxy form is [0, 0, w, h].
            height, width = result.orig_shape[0], result.orig_shape[1]
            # Helper from .utils; presumably snaps box edges lying close to the image border onto the border
            # (and may modify result.boxes.xyxy in place) — TODO confirm against .utils.
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            # Create the reference box on the same device as the boxes it is compared against.
            full_box = torch.tensor([0, 0, width, height], device=boxes.device, dtype=torch.float32)
            # box_iou of the single (1, 4) query against the N boxes gives a (1, N) matrix; take row 0 *before*
            # calling nonzero(). On the 2-D matrix, nonzero() returns (row, col) pairs, and flattening those
            # would interleave the constant row index 0 with the real box indices, wrongly overwriting box 0
            # whenever any box matched.
            iou = box_iou(full_box[None], boxes)[0]
            idx = torch.nonzero(iou > 0.9).flatten()
            if idx.numel():
                # NOTE(review): relies on result.boxes.xyxy being a writable view of the stored box data so
                # that this assignment persists; confirm against the Boxes implementation.
                result.boxes.xyxy[idx] = full_box
        return results