512 lines
17 KiB
Python
512 lines
17 KiB
Python
|
#!/usr/bin/env python
|
||
|
"""
|
||
|
@author Ma Yixiao
@date Dec. 2024
|
||
|
"""
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import actionlib
|
||
|
import PIL
|
||
|
import matplotlib.pyplot as plt
|
||
|
import matplotlib
|
||
|
import skimage.io
|
||
|
import time
|
||
|
import logging
|
||
|
import threading
|
||
|
import copy
|
||
|
import cv2
|
||
|
from skimage.transform import resize
|
||
|
import message_filters
|
||
|
|
||
|
from cv_bridge import CvBridge, CvBridgeError
|
||
|
import math
|
||
|
import random
|
||
|
import sys
|
||
|
|
||
|
from maskrcnn_ros.msg import *
|
||
|
import maskrcnn_ros.msg
|
||
|
|
||
|
import ros_numpy
|
||
|
import numpy as np
|
||
|
import roslib.packages
|
||
|
import rospy
|
||
|
from sensor_msgs.msg import Image
|
||
|
from ultralytics import YOLO
|
||
|
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
|
||
|
from ultralytics_ros.msg import YoloResult
|
||
|
#myx : add for Time of each request
|
||
|
from std_msgs.msg import Float32
|
||
|
|
||
|
import colorsys
|
||
|
import random
|
||
|
|
||
|
import os
|
||
|
#os.environ["CUDA_VISIBLE_DEVICES"]="1"
|
||
|
#myx : gpu
|
||
|
import torch
|
||
|
|
||
|
# Print the PyTorch version and whether CUDA is usable as soon as the module
# is imported, so GPU set-up problems show up in the node's console output.
print(torch.__version__)
print(torch.cuda.is_available())

if torch.cuda.is_available():
    # NOTE(review): the extracted source lost its indentation; both prints
    # are assumed to belong to this branch -- confirm against the original.
    print(torch.cuda.get_device_name(0))
    print(torch.version.cuda)

# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
|
||
|
|
||
|
# Log to a file with a relative name (so it lands in the process's working
# directory). Records below INFO are dropped; timestamps use a 12-hour clock.
logging.basicConfig(filename='yolov8_action_server.log',
                    format='%(asctime)s %(levelname)s:%(message)s',
                    level=logging.INFO,
                    datefmt='%m/%d/%Y %I:%M:%S %p')
|
||
|
|
||
|
|
||
|
# Hard-coded absolute location that is added to the import path below.
ROOT_DIR = os.path.abspath(
    "/root/catkin_ws/src/ultralytics_ros/include/ultralytics/")
print("ROOT_DIR: ", ROOT_DIR)

sys.path.append(ROOT_DIR)  # To find local version of the library

sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))

# Sub-directories derived from ROOT_DIR.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
IMAGE_DIR = os.path.join(ROOT_DIR, "images")
|
||
|
|
||
|
# 80 object category names, used as the default for the "~classes" parameter
# and handed to ros_semantic_result(). Index 0 is 'person'.
# NOTE(review): presumably the list index equals the model's class id
# (COCO ordering) -- confirm against the model that is actually loaded.
class_names = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]
|
||
|
# Total time consumed: segmentation wall time in milliseconds, accumulated
# over every request handled by the action server.
total_time = float(0.0)
# Number of requests timed so far (incremented once per action goal, not per
# image); used as the divisor for the running average.
total_number = int(0)

# To sync semantic segmentation thread and action server thread:
# cn_task is notified by the action server when new images are waiting for
# the worker, cn_ready by the worker once the segmentation results are set.
cn_task = threading.Condition()
cn_ready = threading.Condition()

# Number of images read from every action goal.
batch_size = 2
#myx global results
|
||
|
def random_colors(N, bright=True):
    """
    Return N visually distinct colors as RGB tuples, in random order.

    The hues are spread evenly around the HSV wheel at full saturation and
    converted to RGB; the resulting list is shuffled before it is returned.
    `bright` selects the HSV value component: 1.0 when true, 0.7 otherwise.
    """
    value = 0.7 if not bright else 1.0
    palette = []
    for index in range(N):
        hue = float(index) / N
        palette.append(colorsys.hsv_to_rgb(hue, 1, value))
    random.shuffle(palette)
    return palette
|
||
|
|
||
|
def apply_mask(image, mask, color, alpha=0.5):
    """Blend `color` into `image` wherever `mask` equals 1.

    The first three channels of `image` are modified in place: a masked pixel
    becomes ``pixel * (1 - alpha) + alpha * color[c] * 255`` and every other
    pixel keeps its value. The same `image` object is returned.
    """
    keep = 1 - alpha
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * keep + alpha * color[channel] * 255
        image[:, :, channel] = np.where(mask == 1, tinted, plane)
    return image
|
||
|
|
||
|
def ros_semantic_result(image,
                        boxes,
                        masks,
                        class_ids,
                        class_names,
                        scores=None,
                        show_mask=True,
                        show_bbox=True,
                        colors=None,
                        show_opencv=None):
    """Tint every detected instance on a uint8 copy of `image`.

    Args:
        image: array whose ``shape`` has three elements (height, width,
            channels). It is converted to uint8 in a new array; the input is
            not modified.
        boxes: detection boxes. Only ``boxes.shape[0]`` (the instance count)
            and ``boxes[i]`` (to skip all-zero boxes) are used.
        masks: sequence with one mask per instance; pixels equal to 1 are
            tinted.
        class_ids: array with one entry per instance (only its length is
            checked).
        class_names, scores, show_bbox: accepted for interface compatibility;
            they are not used by the current implementation.
        show_mask: when false no mask is blended into the image.
        colors: optional sequence with one color per instance (components are
            scaled by 255 when blended). Random colors are generated when it
            is missing or empty.
        show_opencv: when truthy the result is shown in an OpenCV window
            named "Result".

    Returns:
        The uint8 image with the masks blended in.

    Raises:
        ValueError: if `image` is None.
        AssertionError: if the number of boxes, masks and class ids differ.
    """
    # The original checked for None only after reading image.shape, so the
    # check could never succeed; fail early with a clear message instead.
    if image is None:
        raise ValueError("ros_semantic_result() needs an image, got None")

    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == len(masks) == class_ids.shape[0]

    # Generate random colors
    colors = colors or random_colors(N)

    height, width, channels = image.shape
    print("h*w*c*dtype= ", height, width, channels, image.dtype.type)

    # astype() already returns a new array, so no extra copy is needed.
    masked_image = image.astype(np.uint8)

    for i in range(N):
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        if show_mask:
            masked_image = apply_mask(masked_image, masks[i], colors[i])

    # Debug dump of the overlay (every instance has a different color).
    # NOTE(review): the extracted source lost its indentation, so it is not
    # clear whether this write sat inside the loop; it is done once per call
    # here because only the final image survives on disk either way.
    if N:
        cv2.imwrite('/root/Dataset/output!!!!!!!!!!!!@@@@@.jpg', masked_image)

    if show_opencv and N and np.any(masked_image):
        cv2.imshow("Result", masked_image)
        cv2.waitKey(1)

    return masked_image
|
||
|
|
||
|
def save_image(image,i,id):
    """Write `image` to /root/Dataset/<i>_<id>.jpg with OpenCV."""
    basename = "{!s}_{!s}.jpg".format(i, id)
    cv2.imwrite(os.path.join("/root/Dataset/", basename), image)
|
||
|
def save_mask_image(image,i,id):
    """Write `image` to /root/Dataset/<i>___<id>.jpg with OpenCV."""
    basename = "{!s}___{!s}.jpg".format(i, id)
    cv2.imwrite(os.path.join("/root/Dataset/", basename), image)
|
||
|
|
||
|
class YoloV8(object):
    """Ultralytics YOLO model wrapper used by the segmentation worker.

    The model file is looked up in the ``models`` directory of the
    ``ultralytics_ros`` package; overlays can be published on the
    ``yolo_image`` topic.
    """

    def __init__(self):
        """Load the model named by ``~yolo_model`` and create the publisher."""
        print('Init YOLO')

        yolo_model = rospy.get_param("~yolo_model", "yolov8n-seg.pt")
        self.classes = rospy.get_param("~classes", class_names)
        path = roslib.packages.get_pkg_dir("ultralytics_ros")
        self.model = YOLO(f"{path}/models/{yolo_model}")
        self.model.fuse()
        self.masked_image__pub = rospy.Publisher("yolo_image", Image, queue_size=10)
        self.use_segmentation = yolo_model.endswith("-seg.pt")
        # GPU 0 when CUDA is usable (what the original hard-coded), otherwise
        # the CPU, so the node also runs on hosts without a GPU.
        self.device = 0 if torch.cuda.is_available() else "cpu"
        # One converter is enough; it was previously rebuilt for every publish.
        self.bridge = CvBridge()
        print("=============Initialized YOLO==============")

    def segment(self, image):
        """Run the model on `image` and render an overlay for the first frame.

        Args:
            image: prediction source handed unchanged to ``model.predict``.
                ``image[0]`` is used as the canvas of the overlay (the caller
                in this file passes a list of BGR arrays).

        Returns:
            ``(masked_image, results)``: the overlay for the first frame (an
            all-zero image of the same shape when nothing was detected or the
            model produced no masks) and the list returned by
            ``model.predict``.
        """
        timer_start = rospy.Time.now()
        results = self.model.predict(
            source=image,
            device=self.device,
            retina_masks=True,
        )
        segment_time = (rospy.Time.now() - timer_start).to_sec() * 1000
        print("predict time: %f ms \n" % segment_time)
        print("res.len" ,len(results))

        timer_start = rospy.Time.now()
        first = results[0]
        # masks is None when the model has no segmentation head; iterating it
        # would raise, so treat that case like "nothing detected".
        if first.boxes.cls.numel() != 0 and first.masks is not None:
            boxes = first.boxes
            class_ids = np.array([int(cls) for cls in boxes.cls])
            scores = np.array([float(conf) for conf in boxes.conf])
            masks = [
                np.squeeze(mask_tensor.data.to("cpu").detach().numpy()).astype(
                    np.uint8)
                for mask_tensor in first.masks
            ]

            self.masked_image_ = ros_semantic_result(image[0],
                                                     boxes,
                                                     masks,
                                                     class_ids,
                                                     class_names,
                                                     scores,
                                                     show_opencv=False)
        else:
            print("..............111...............")
            self.masked_image_ = np.zeros(image[0].shape, dtype=np.uint8)

        segment_time = (rospy.Time.now() - timer_start).to_sec() * 1000
        print("Visualize time: %f ms \n" % segment_time)

        return self.masked_image_, results

    def publish_result(self, image, stamp):
        """Publish `image` on ``yolo_image`` with `stamp` as header time.

        Nothing is published when either argument is None.
        """
        if image is None or stamp is None:
            print("Image invalid")
            return
        msg_img = self.bridge.cv2_to_imgmsg(image, encoding='passthrough')
        msg_img.header.stamp = stamp
        self.masked_image__pub.publish(msg_img)
|
||
|
|
||
|
def worker(yolov8, bPublish_result=False):
    """Segmentation thread: wait for a task, run YOLO, signal the result.

    The thread waits on ``cn_task``; when notified it segments the global
    ``color_image``, stores the outcome in the globals ``masked_image`` and
    ``result`` and notifies ``cn_ready``. The loop ends when ``bStartYOLO``
    is cleared or ROS shuts down.

    Args:
        yolov8: a YoloV8 instance, or a class/factory that is called without
            arguments to build one. None builds a YoloV8. (The argument used
            to be ignored and a YoloV8 was always constructed.)
        bPublish_result: when true the overlay of the first image is also
            published with that image's time stamp.
    """
    global bStartYOLO
    global color_image
    global stamp
    global masked_image
    global result

    bStartYOLO = True
    color_image = []
    stamp = []

    # cn_task stays locked for the whole life of the thread and is released
    # only inside wait(). A notifier therefore cannot run while this thread
    # is loading the model or processing; it blocks on the lock until the
    # thread is waiting again, so its notification is not lost.
    with cn_task:
        if yolov8 is None:
            yolov8 = YoloV8()
        elif callable(yolov8):
            yolov8 = yolov8()

        while bStartYOLO and not rospy.is_shutdown():
            # Timed wait: lets the loop condition be re-checked so the thread
            # can finish at shutdown instead of blocking forever.
            if not cn_task.wait(timeout=0.5):
                continue

            print("New task comming")
            timer_start = rospy.Time.now()
            print("color_image&&&&&&&&&&&&&&&&& len &&&&&&&&&&&",len(color_image))
            masked_image, result = yolov8.segment(color_image)

            segment_time = (rospy.Time.now() - timer_start).to_sec() * 1000
            print("Yolo segment time for cuurent image: %f ms \n" %
                  segment_time)

            with cn_ready:
                cn_ready.notify_all()
                if bPublish_result:
                    # The overlay belongs to the first image, so use that
                    # image's stamp (the whole list was passed before).
                    yolov8.publish_result(masked_image,
                                          stamp[0] if stamp else None)

    print("Exit YoloV8 thread")
|
||
|
|
||
|
class SemanticActionServer(object):
    """actionlib server that answers ``batch`` goals with label images.

    Each goal carries ``batch_size`` images. They are handed to the
    segmentation worker through module globals and the two condition
    variables ``cn_task`` / ``cn_ready``; the reply holds one mono8 label
    image per input image.
    """

    # Message instances shared by all goals and reused for every reply.
    _feedback = maskrcnn_ros.msg.batchFeedback()
    _result = maskrcnn_ros.msg.batchResult()

    def __init__(self):
        """Create and start the action server named by /semantic/action_name."""
        print("Initialize Action Server")

        # Get action server name
        self._action_name = rospy.get_param('/semantic/action_name',
                                            '/yolov8_action_server')
        print("Action name: ", self._action_name)
        # Start Action server
        self._as = actionlib.SimpleActionServer(
            self._action_name,
            maskrcnn_ros.msg.batchAction,
            execute_cb=self.execute_cb,
            auto_start=False)
        self._as.start()
        print("YoloV8 action server start...")

    @staticmethod
    def _build_label_image(detections):
        """Encode the instance masks of one prediction as a uint8 label image.

        Persons (class id 0) get the values 1, 2, 3, ...; every other instance
        gets 10, 20, 30, ... in detection order. Masks smaller than 200 pixels
        are ignored. Where masks overlap, the earlier instance keeps the
        pixel; the previous code added the values, which produced labels that
        belong to no instance.
        NOTE(review): person labels reach 10 at the tenth person and then
        collide with the first non-person label -- confirm the intended
        encoding with the consumer of these images.
        """
        masks = [
            np.squeeze(mask_tensor.data.to("cpu").detach().numpy()).astype(
                np.uint8)
            for mask_tensor in detections.masks
        ]
        label_img = np.zeros(masks[0].shape, dtype=np.uint8)
        logging.info('shap of label image: %s', label_img.shape)

        class_color = 10  # the first object ID is 10, second is 20 ...
        person_color = 1  # every person has a different value
        for cls, mask in zip(detections.boxes.cls, masks):
            covered = mask == 1
            if np.sum(covered) < 200:
                continue
            if int(cls) == 0:
                value = person_color
                person_color += 1
            else:
                value = class_color
                class_color += 10
            if value > 255:
                # A mono8 image cannot hold the label; skip the instance
                # rather than let the value wrap around.
                logging.warning('label %d does not fit in mono8, skipped',
                                value)
                continue
            label_img[covered & (label_img == 0)] = value
        return label_img

    def execute_cb(self, goal):
        """Segment the images of `goal` and reply with their label images.

        The goal is aborted when it holds fewer than ``batch_size`` images or
        an image cannot be converted. When the first image is entirely zero
        no segmentation is requested and all label images are empty.
        """
        global total_time
        global total_number
        global color_image
        global stamp
        global result

        # clear the old data
        color_image = []
        stamp = []

        if not self._as.is_active():
            # Was logged at DEBUG, which the INFO log level filtered out.
            logging.error("[Error] YoloV8 action server cannot active")
            return

        time_each_loop_start = rospy.Time.now()

        print("----------------------------------")
        print("ID: %d" % goal.id)

        if len(goal.image) < batch_size:
            reason = "expected %d images, got %d" % (batch_size,
                                                     len(goal.image))
            print(reason)
            self._as.set_aborted(text=reason)
            return

        bridge = CvBridge()
        try:
            for i in range(batch_size):
                color_image.append(
                    bridge.imgmsg_to_cv2(goal.image[i], 'bgr8'))
                print("Number of images:", len(goal.image))
                stamp.append(goal.image[i].header.stamp)
        except CvBridgeError as e:
            print(e)
            # Give the goal a terminal status so the client is not left with
            # an implicit abort.
            self._as.set_aborted(text=str(e))
            return

        timer_start = rospy.Time.now()
        # perform segmenting
        segmented = False
        if np.any(color_image[0]):
            # cn_ready is taken before the worker is woken, so the worker
            # cannot signal completion before this thread is waiting for it.
            with cn_ready:
                with cn_task:
                    print("Inform new task")
                    cn_task.notify_all()
                cn_ready.wait()
                print("semantic result ready")
            segmented = True

        segment_time = (rospy.Time.now() - timer_start).to_sec() * 1000
        print("Yolov8 segment time: %f ms \n" % segment_time)
        # calculate average time consuming
        total_time = float(total_time) + float(segment_time)
        total_number = total_number + 1
        average = total_time / float(total_number)
        print("Average time: %f ms" % average)

        labelMsgs = []
        for i in range(batch_size):
            # Without a fresh segmentation `result` is stale (or undefined on
            # the first request), so it must not be read.
            detections = result[i] if segmented else None
            if (detections is not None
                    and detections.boxes.cls.numel() != 0
                    and detections.masks is not None):
                label_img = self._build_label_image(detections)
                cv2.imwrite('/root/Dataset/output_all_mask.jpg', label_img)
            else:
                print(".................22222...................")
                label_img = np.zeros(color_image[i].shape[:2], dtype=np.uint8)
            labelMsgs.append(
                bridge.cv2_to_imgmsg(label_img.astype(np.uint8),
                                     encoding='mono8'))

        self._result.id = goal.id
        self._result.label = labelMsgs  # all objects' mask

        # Feedback is only meaningful while the goal is active, so publish it
        # before the terminal status is set.
        self._feedback.complete = True
        self._as.publish_feedback(self._feedback)
        self._as.set_succeeded(self._result)

        time_each_loop = (rospy.Time.now() -
                          time_each_loop_start).to_sec() * 1000
        print("Time of each request: %f ms \n" % time_each_loop)
|
||
|
|
||
|
|
||
|
def main(args):
    """Start the ROS node, the YOLO worker thread and the action server.

    Blocks in ``rospy.spin()`` until the node is shut down, then asks the
    worker to stop and waits briefly for it.

    Args:
        args: command-line arguments (currently unused).
    """
    # Without this declaration the assignment below only created a local
    # variable and never reached the worker loop.
    global bStartYOLO

    rospy.init_node('yolov8_server', anonymous=False)

    th_yolov8 = threading.Thread(name='yolov8',
                                 target=worker,
                                 args=(
                                     YoloV8,
                                     False,
                                 ))
    # The worker may be blocked on a condition variable at shutdown; as a
    # daemon it cannot keep the process alive.
    th_yolov8.daemon = True
    th_yolov8.start()

    # Keep a reference so the server lives as long as the node is spinning.
    actionserver = SemanticActionServer()

    print('Setting up YoloV8 Action Server...')

    logging.debug('Waiting for worker threads')

    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")

    # Ask the worker loop to finish and give it a bounded amount of time; an
    # unbounded join() would hang on a worker that never wakes up.
    bStartYOLO = False
    th_yolov8.join(timeout=2.0)
|
||
|
|
||
|
|
||
|
# Run the node only when this file is executed as a script.
if __name__ == '__main__':
    main(sys.argv)
|