GeoYolo-SLAM/ultralytics_ros/script/predict_node.py

124 lines
4.9 KiB
Python
Raw Normal View History

2025-04-09 16:05:54 +08:00
#!/usr/bin/env python3
import ros_numpy
import numpy as np
import roslib.packages
import rospy
from sensor_msgs.msg import Image
from ultralytics import YOLO
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
from ultralytics_ros.msg import YoloResult
class PredictNode:
    """ROS node that runs an Ultralytics YOLO model on incoming images.

    For every ``sensor_msgs/Image`` received on the input topic the node runs
    ``model.predict`` and publishes:

    * a ``YoloResult`` message holding the 2D detections and, when the model
      file name ends in ``-seg.pt``, one ``mono8`` mask image per instance;
    * an ``Image`` with the predictions drawn on it.
    """

    def __init__(self):
        """Read the private parameters, load the model and set up the topics.

        Private parameters read (default in parentheses): ``~yolo_model``
        (``yolov8n.pt``), ``~input_topic`` (``image_raw``), ``~result_topic``
        (``yolo_result``), ``~result_image_topic`` (``yolo_image``),
        ``~conf_thres`` (0.25), ``~iou_thres`` (0.45), ``~max_det`` (300),
        ``~classes`` (None), ``~device`` (None), ``~result_conf`` (True),
        ``~result_line_width`` (None), ``~result_font_size`` (None),
        ``~result_font`` (``Arial.ttf``), ``~result_labels`` (True) and
        ``~result_boxes`` (True).
        """
        yolo_model = rospy.get_param("~yolo_model", "yolov8n.pt")
        self.input_topic = rospy.get_param("~input_topic", "image_raw")
        self.result_topic = rospy.get_param("~result_topic", "yolo_result")
        self.result_image_topic = rospy.get_param(
            "~result_image_topic", "yolo_image"
        )
        # Arguments forwarded to model.predict().
        self.conf_thres = rospy.get_param("~conf_thres", 0.25)
        self.iou_thres = rospy.get_param("~iou_thres", 0.45)
        self.max_det = rospy.get_param("~max_det", 300)
        self.classes = rospy.get_param("~classes", None)
        self.device = rospy.get_param("~device", None)
        # Arguments forwarded to results[0].plot().
        self.result_conf = rospy.get_param("~result_conf", True)
        self.result_line_width = rospy.get_param("~result_line_width", None)
        self.result_font_size = rospy.get_param("~result_font_size", None)
        self.result_font = rospy.get_param("~result_font", "Arial.ttf")
        self.result_labels = rospy.get_param("~result_labels", True)
        self.result_boxes = rospy.get_param("~result_boxes", True)

        # The weights are looked up in the package's own "models" directory.
        path = roslib.packages.get_pkg_dir("ultralytics_ros")
        self.model = YOLO(f"{path}/models/{yolo_model}")
        self.model.fuse()
        # Segmentation output is enabled purely from the weights' file name.
        self.use_segmentation = yolo_model.endswith("-seg.pt")

        self.results_pub = rospy.Publisher(
            self.result_topic, YoloResult, queue_size=1
        )
        self.result_image_pub = rospy.Publisher(
            self.result_image_topic, Image, queue_size=1
        )
        # Subscribe last: the callback reads the publishers and
        # ``use_segmentation``, so they must exist before the first message
        # can be delivered.
        self.sub = rospy.Subscriber(
            self.input_topic,
            Image,
            self.image_callback,
            queue_size=1,
            buff_size=2**24,
        )

    def image_callback(self, msg):
        """Run prediction on one image message and publish the results.

        Args:
            msg: Incoming ``sensor_msgs/Image``. Its header is copied onto
                both published messages and its encoding is reused for the
                annotated image.
        """
        encoding = msg.encoding
        # NOTE(review): the array is handed to the model in the message's own
        # channel order; confirm the camera encoding matches what the model
        # expects for numpy sources.
        numpy_image = ros_numpy.numpify(msg)
        results = self.model.predict(
            source=numpy_image,
            conf=self.conf_thres,
            iou=self.iou_thres,
            max_det=self.max_det,
            classes=self.classes,
            device=self.device,
            verbose=False,
            retina_masks=True,
        )
        # The helpers below index results[0], so an empty result list (as
        # well as None) means there is nothing to publish.
        if not results:
            return

        yolo_result_msg = YoloResult()
        yolo_result_msg.header = msg.header
        yolo_result_msg.detections = self.create_detections_array(results)

        yolo_result_image_msg = self.create_result_image(results, encoding)
        # Stamp the header on the message that is actually published; setting
        # it before create_result_image() would be lost when the name is
        # rebound to the freshly built message.
        yolo_result_image_msg.header = msg.header

        if self.use_segmentation:
            yolo_result_msg.masks = self.create_segmentation_masks(results)

        self.results_pub.publish(yolo_result_msg)
        self.result_image_pub.publish(yolo_result_image_msg)

    def create_detections_array(self, results):
        """Build a ``Detection2DArray`` from the first prediction result.

        Args:
            results: Sequence returned by ``model.predict``; only
                ``results[0]`` is read.

        Returns:
            A ``Detection2DArray`` with one ``Detection2D`` per box. Each box
            is stored as centre x/y plus width/height (taken from
            ``boxes.xywh``) with a single hypothesis carrying the class id
            and confidence.
        """
        detections_msg = Detection2DArray()
        boxes = results[0].boxes
        if boxes is None:
            return detections_msg

        # One bulk conversion per tensor instead of a scalar conversion for
        # every element inside the loop.
        rows = zip(boxes.xywh.tolist(), boxes.cls.tolist(), boxes.conf.tolist())
        for (center_x, center_y, width, height), cls, conf in rows:
            detection = Detection2D()
            detection.bbox.center.x = float(center_x)
            detection.bbox.center.y = float(center_y)
            detection.bbox.size_x = float(width)
            detection.bbox.size_y = float(height)
            hypothesis = ObjectHypothesisWithPose()
            hypothesis.id = int(cls)
            hypothesis.score = float(conf)
            detection.results.append(hypothesis)
            detections_msg.detections.append(detection)
        return detections_msg

    def create_result_image(self, results, encoding):
        """Render the first result and wrap it in an ``Image`` message.

        Args:
            results: Sequence returned by ``model.predict``; only
                ``results[0]`` is plotted.
            encoding: Encoding string to put on the outgoing message (the
                callback passes the encoding of the input image).

        Returns:
            The annotated image as a ``sensor_msgs/Image``. The header is
            left for the caller to fill in.
        """
        plotted_image = results[0].plot(
            conf=self.result_conf,
            line_width=self.result_line_width,
            font_size=self.result_font_size,
            font=self.result_font,
            labels=self.result_labels,
            boxes=self.result_boxes,
        )
        return ros_numpy.msgify(Image, plotted_image, encoding=encoding)

    @staticmethod
    def _to_mono8(mask_array: np.ndarray) -> np.ndarray:
        """Drop singleton axes, cast to ``uint8`` and scale by 255.

        For a mask holding 0/1 values this yields a 0/255 image.
        """
        return np.squeeze(mask_array).astype(np.uint8) * 255

    def create_segmentation_masks(self, results):
        """Convert every instance mask in ``results`` to a ``mono8`` image.

        Args:
            results: Sequence returned by ``model.predict``. Results without
                a ``masks`` attribute, or whose ``masks`` is ``None``, are
                skipped.

        Returns:
            A list of ``sensor_msgs/Image`` messages, one per mask, in
            iteration order.
        """
        masks_msg = []
        for result in results:
            masks = getattr(result, "masks", None)
            if masks is None:
                continue
            for mask_tensor in masks:
                # Move the mask to host memory before handing it to numpy.
                mask_numpy = self._to_mono8(
                    mask_tensor.data.to("cpu").detach().numpy()
                )
                masks_msg.append(
                    ros_numpy.msgify(Image, mask_numpy, encoding="mono8")
                )
        return masks_msg
def main():
    """Start the ``predict_node`` ROS node and hand control to ``rospy.spin``."""
    rospy.init_node("predict_node")
    # Keep the node object referenced for as long as spin() is blocking.
    node = PredictNode()  # noqa: F841
    rospy.spin()


if __name__ == "__main__":
    main()