124 lines
4.9 KiB
Python
Executable File
124 lines
4.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import ros_numpy
|
|
import numpy as np
|
|
import roslib.packages
|
|
import rospy
|
|
from sensor_msgs.msg import Image
|
|
from ultralytics import YOLO
|
|
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
|
|
from ultralytics_ros.msg import YoloResult
|
|
|
|
|
|
class PredictNode:
    """ROS node that runs an Ultralytics YOLO model on incoming images.

    For every ``sensor_msgs/Image`` received on the input topic the node
    runs ``YOLO.predict`` and publishes:

    * a ``YoloResult`` message (detections, plus masks for ``-seg`` models);
    * the annotated image produced by ``results[0].plot``.
    """

    def __init__(self):
        """Read private parameters, load the model, then wire up pub/sub."""
        yolo_model = rospy.get_param("~yolo_model", "yolov8n.pt")

        # Topic names.
        self.input_topic = rospy.get_param("~input_topic", "image_raw")
        self.result_topic = rospy.get_param("~result_topic", "yolo_result")
        self.result_image_topic = rospy.get_param("~result_image_topic", "yolo_image")

        # Arguments forwarded to ``YOLO.predict``.
        self.conf_thres = rospy.get_param("~conf_thres", 0.25)
        self.iou_thres = rospy.get_param("~iou_thres", 0.45)
        self.max_det = rospy.get_param("~max_det", 300)
        self.classes = rospy.get_param("~classes", None)
        self.device = rospy.get_param("~device", None)

        # Arguments forwarded to ``Results.plot`` for the annotated image.
        self.result_conf = rospy.get_param("~result_conf", True)
        self.result_line_width = rospy.get_param("~result_line_width", None)
        self.result_font_size = rospy.get_param("~result_font_size", None)
        self.result_font = rospy.get_param("~result_font", "Arial.ttf")
        self.result_labels = rospy.get_param("~result_labels", True)
        self.result_boxes = rospy.get_param("~result_boxes", True)

        # Model weights are looked up in this package's ``models`` directory.
        path = roslib.packages.get_pkg_dir("ultralytics_ros")
        self.model = YOLO(f"{path}/models/{yolo_model}")
        self.model.fuse()

        # Masks are only published for segmentation weights ("*-seg.pt").
        self.use_segmentation = yolo_model.endswith("-seg.pt")

        self.results_pub = rospy.Publisher(self.result_topic, YoloResult, queue_size=1)
        self.result_image_pub = rospy.Publisher(
            self.result_image_topic, Image, queue_size=1
        )

        # Subscribe LAST: the callback reads the publishers and
        # ``use_segmentation``, and may be invoked as soon as the
        # subscriber is registered.
        self.sub = rospy.Subscriber(
            self.input_topic,
            Image,
            self.image_callback,
            queue_size=1,
            buff_size=2**24,
        )

    def image_callback(self, msg):
        """Run prediction on one image message and publish the results.

        Args:
            msg: Incoming ``sensor_msgs/Image``. Its header is copied onto
                both outgoing messages and its encoding is reused for the
                annotated image.
        """
        numpy_image = ros_numpy.numpify(msg)
        results = self.model.predict(
            source=numpy_image,
            conf=self.conf_thres,
            iou=self.iou_thres,
            max_det=self.max_det,
            classes=self.classes,
            device=self.device,
            verbose=False,
            retina_masks=True,
        )

        # The helpers index ``results[0]``, so bail out on None *and* on an
        # empty result list.
        if not results:
            return

        yolo_result_msg = YoloResult()
        yolo_result_msg.header = msg.header
        yolo_result_msg.detections = self.create_detections_array(results)
        if self.use_segmentation:
            yolo_result_msg.masks = self.create_segmentation_masks(results)

        # The header must be set on the message that is actually published:
        # ``create_result_image`` returns a freshly built Image, so assigning
        # the header to a placeholder beforehand would be lost.
        yolo_result_image_msg = self.create_result_image(results, msg.encoding)
        yolo_result_image_msg.header = msg.header

        self.results_pub.publish(yolo_result_msg)
        self.result_image_pub.publish(yolo_result_image_msg)

    def create_detections_array(self, results):
        """Convert the boxes of ``results[0]`` into a ``Detection2DArray``.

        Args:
            results: Sequence returned by ``YOLO.predict``; only the first
                element is used.

        Returns:
            A ``Detection2DArray`` with one ``Detection2D`` per box, each
            carrying a single ``ObjectHypothesisWithPose`` (class id, score).
        """
        detections_msg = Detection2DArray()
        boxes = results[0].boxes

        for bbox, cls, conf in zip(boxes.xywh, boxes.cls, boxes.conf):
            detection = Detection2D()
            # ``xywh`` rows are consumed as (center x, center y, width, height).
            detection.bbox.center.x = float(bbox[0])
            detection.bbox.center.y = float(bbox[1])
            detection.bbox.size_x = float(bbox[2])
            detection.bbox.size_y = float(bbox[3])

            hypothesis = ObjectHypothesisWithPose()
            hypothesis.id = int(cls)
            hypothesis.score = float(conf)
            detection.results.append(hypothesis)

            detections_msg.detections.append(detection)

        return detections_msg

    def create_result_image(self, results, encoding):
        """Render ``results[0]`` with annotations and wrap it in an Image.

        Args:
            results: Sequence returned by ``YOLO.predict``; only the first
                element is plotted.
            encoding: Encoding string to stamp on the outgoing message
                (the callback passes the encoding of the input image).

        Returns:
            A ``sensor_msgs/Image`` built by ``ros_numpy.msgify``. The header
            is NOT populated here; the caller is responsible for it.
        """
        plotted_image = results[0].plot(
            conf=self.result_conf,
            line_width=self.result_line_width,
            font_size=self.result_font_size,
            font=self.result_font,
            labels=self.result_labels,
            boxes=self.result_boxes,
        )
        return ros_numpy.msgify(Image, plotted_image, encoding=encoding)

    def create_segmentation_masks(self, results):
        """Convert every instance mask in ``results`` to a ``mono8`` Image.

        Args:
            results: Sequence returned by ``YOLO.predict``.

        Returns:
            A list of ``sensor_msgs/Image`` messages, one per mask; empty
            when no result carries masks.
        """
        masks_msg = []
        for result in results:
            masks = getattr(result, "masks", None)
            if masks is None:
                continue
            for mask_tensor in masks:
                mask_array = mask_tensor.data.to("cpu").detach().numpy()
                # Drop singleton dimensions and scale to 0/255.
                # NOTE(review): assumes mask values are 0/1 -- confirm.
                mask_numpy = np.squeeze(mask_array).astype(np.uint8) * 255
                masks_msg.append(
                    ros_numpy.msgify(Image, mask_numpy, encoding="mono8")
                )
        return masks_msg
|
|
|
|
|
|
def main():
    """Start the ``predict_node`` ROS node and block until shutdown."""
    rospy.init_node("predict_node")
    # Bind the instance to a name so it is referenced while spinning.
    node = PredictNode()  # noqa: F841
    rospy.spin()


if __name__ == "__main__":
    main()
|