#!/usr/bin/env python3
"""ROS node that runs an Ultralytics YOLO model on incoming image messages."""

import ros_numpy
import numpy as np
import roslib.packages
import rospy
from sensor_msgs.msg import Image
from ultralytics import YOLO
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose

from ultralytics_ros.msg import YoloResult


class PredictNode:
    """Run YOLO prediction on a subscribed image topic and publish the results.

    For every received ``sensor_msgs/Image`` the node publishes:

    * a ``YoloResult`` on ``~result_topic`` holding the detections and, when
      the model file name ends in ``-seg.pt``, one ``mono8`` mask image per
      detected instance;
    * an annotated ``sensor_msgs/Image`` on ``~result_image_topic``.

    All settings are read from private ROS parameters in ``__init__``.
    """

    def __init__(self):
        """Read parameters, load the model and set up publishers/subscriber."""
        yolo_model = rospy.get_param("~yolo_model", "yolov8n.pt")
        self.input_topic = rospy.get_param("~input_topic", "image_raw")
        self.result_topic = rospy.get_param("~result_topic", "yolo_result")
        self.result_image_topic = rospy.get_param("~result_image_topic", "yolo_image")

        # Options forwarded to ``YOLO.predict``.
        self.conf_thres = rospy.get_param("~conf_thres", 0.25)
        self.iou_thres = rospy.get_param("~iou_thres", 0.45)
        self.max_det = rospy.get_param("~max_det", 300)
        self.classes = rospy.get_param("~classes", None)
        self.device = rospy.get_param("~device", None)

        # Options forwarded to ``Results.plot`` for the annotated image.
        self.result_conf = rospy.get_param("~result_conf", True)
        self.result_line_width = rospy.get_param("~result_line_width", None)
        self.result_font_size = rospy.get_param("~result_font_size", None)
        self.result_font = rospy.get_param("~result_font", "Arial.ttf")
        self.result_labels = rospy.get_param("~result_labels", True)
        self.result_boxes = rospy.get_param("~result_boxes", True)

        # The model file is looked up in this package's ``models`` directory.
        path = roslib.packages.get_pkg_dir("ultralytics_ros")
        self.model = YOLO(f"{path}/models/{yolo_model}")
        self.model.fuse()

        # Segmentation output is enabled purely from the model file name.
        self.use_segmentation = yolo_model.endswith("-seg.pt")

        self.results_pub = rospy.Publisher(self.result_topic, YoloResult, queue_size=1)
        self.result_image_pub = rospy.Publisher(
            self.result_image_topic, Image, queue_size=1
        )

        # Subscribe last: the callback may be invoked from another thread as
        # soon as the subscriber exists, and it relies on every attribute
        # assigned above (model, publishers, use_segmentation).
        self.sub = rospy.Subscriber(
            self.input_topic,
            Image,
            self.image_callback,
            queue_size=1,
            buff_size=2**24,
        )

    def image_callback(self, msg):
        """Run prediction on one image message and publish both outputs.

        Args:
            msg: Incoming ``sensor_msgs/Image``. Its header is copied onto
                both published messages, and its encoding is reused for the
                annotated image.
        """
        encoding = msg.encoding
        numpy_image = ros_numpy.numpify(msg)
        results = self.model.predict(
            source=numpy_image,
            conf=self.conf_thres,
            iou=self.iou_thres,
            max_det=self.max_det,
            classes=self.classes,
            device=self.device,
            verbose=False,
            retina_masks=True,
        )
        # Nothing to publish when prediction yields no result objects; this
        # also protects the ``results[0]`` accesses in the helpers below.
        if not results:
            return

        yolo_result_msg = YoloResult()
        yolo_result_msg.header = msg.header
        yolo_result_msg.detections = self.create_detections_array(results)

        # The header must be assigned AFTER the message is built: the helper
        # returns a brand-new Image, so a header set beforehand would be lost.
        yolo_result_image_msg = self.create_result_image(results, encoding)
        yolo_result_image_msg.header = msg.header

        if self.use_segmentation:
            yolo_result_msg.masks = self.create_segmentation_masks(results)

        self.results_pub.publish(yolo_result_msg)
        self.result_image_pub.publish(yolo_result_image_msg)

    def create_detections_array(self, results):
        """Convert the boxes of the first result into a ``Detection2DArray``.

        Args:
            results: Non-empty sequence returned by ``YOLO.predict``; only
                ``results[0]`` is read.

        Returns:
            A ``Detection2DArray`` with one ``Detection2D`` per box. Each
            detection carries the box centre/size taken from ``boxes.xywh``
            and a single hypothesis with the class id and confidence.
        """
        detections_msg = Detection2DArray()
        boxes = results[0].boxes
        for bbox, cls, conf in zip(boxes.xywh, boxes.cls, boxes.conf):
            detection = Detection2D()
            detection.bbox.center.x = float(bbox[0])
            detection.bbox.center.y = float(bbox[1])
            detection.bbox.size_x = float(bbox[2])
            detection.bbox.size_y = float(bbox[3])

            hypothesis = ObjectHypothesisWithPose()
            hypothesis.id = int(cls)
            hypothesis.score = float(conf)
            detection.results.append(hypothesis)

            detections_msg.detections.append(detection)
        return detections_msg

    def create_result_image(self, results, encoding):
        """Render the first result and wrap it in a ``sensor_msgs/Image``.

        Args:
            results: Non-empty sequence returned by ``YOLO.predict``; only
                ``results[0]`` is plotted.
            encoding: Encoding string written into the output message.

        Returns:
            A new ``Image`` message. Its header is left at the default; the
            caller is responsible for filling it in.
        """
        plotted_image = results[0].plot(
            conf=self.result_conf,
            line_width=self.result_line_width,
            font_size=self.result_font_size,
            font=self.result_font,
            labels=self.result_labels,
            boxes=self.result_boxes,
        )
        # NOTE(review): assumes the plotted array is compatible with the
        # input message's encoding (same channel count/order) - confirm for
        # non-3-channel inputs.
        return ros_numpy.msgify(Image, plotted_image, encoding=encoding)

    def create_segmentation_masks(self, results):
        """Build one ``mono8`` image message per instance mask.

        Args:
            results: Iterable of prediction results; entries without masks
                (attribute missing or ``None``) are skipped.

        Returns:
            A list of ``Image`` messages in which mask values are cast to
            ``uint8`` and multiplied by 255.
        """
        masks_msg = []
        for result in results:
            masks = getattr(result, "masks", None)
            if masks is None:
                continue
            for mask_tensor in masks:
                # Move to CPU, drop singleton dimensions and scale for mono8.
                mask_numpy = (
                    np.squeeze(mask_tensor.data.to("cpu").detach().numpy()).astype(
                        np.uint8
                    )
                    * 255
                )
                masks_msg.append(
                    ros_numpy.msgify(Image, mask_numpy, encoding="mono8")
                )
        return masks_msg


if __name__ == "__main__":
    rospy.init_node("predict_node")
    node = PredictNode()
    rospy.spin()