pose-detect/ultralytics/models/fastsam/predict.py

32 lines
1.3 KiB
Python

# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils.metrics import box_iou
from .utils import adjust_bboxes_to_image_border
class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in the
    Ultralytics YOLO framework.

    This class extends SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
    class segmentation. On top of the parent's post-processing, any predicted box that almost entirely covers the
    image (IoU > 0.9 with the full-image box) is snapped to the exact full-image box.
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply FastSAM-specific box post-processing to the results of SegmentationPredictor.postprocess.

        For each result, the boxes are passed through ``adjust_bboxes_to_image_border`` and then compared with the
        full-image box ``[0, 0, width, height]``. Every box whose IoU with it exceeds 0.9 is overwritten with that
        full-image box via ``result.boxes.xyxy[idx] = full_box``.

        NOTE(review): the overwrite (and presumably the border adjustment) only persists if ``result.boxes.xyxy``
        is a view onto the underlying box data rather than a copy — confirm against the Boxes implementation.

        Args:
            preds: Raw model outputs, forwarded to the parent postprocess; ``preds[0].device`` picks the device
                for the full-image box.
            img: Preprocessed input batch, forwarded to the parent postprocess.
            orig_imgs: Original input images, forwarded to the parent postprocess.

        Returns:
            The results list returned by the parent postprocess, after the box adjustment above.
        """
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            # orig_shape is indexed as (height, width), so the full box is [x1=0, y1=0, x2=width, y2=height].
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            # box_iou(full_box[None], boxes) is a (1, N) matrix. Take row 0 so nonzero() yields only box indices:
            # on the 2-D matrix it returns (row, col) pairs, and flattening those would interleave the row index 0
            # and wrongly overwrite box 0 whenever any box matched.
            idx = torch.nonzero(box_iou(full_box[None], boxes)[0] > 0.9).flatten()
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box
        return results