131 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			131 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import argparse | ||
|  | import os | ||
|  | from glob import glob | ||
|  | 
 | ||
|  | import cv2 | ||
|  | import FramesViewer.utils as fv_utils | ||
|  | import gymnasium as gym | ||
|  | import mujoco | ||
|  | import numpy as np | ||
|  | from gymnasium.envs.registration import register | ||
|  | from sb3_contrib import TQC | ||
|  | from stable_baselines3 import A2C, PPO, SAC, TD3 | ||
|  | 
 | ||
# Make the footstep-following BDX environment constructible via
# gym.make("BDX_env"); the implementation lives in footsteps_env.BDXEnv.
# autoreset=True asks Gymnasium to reset the env automatically once an
# episode ends. Episode length is currently uncapped; pass
# max_episode_steps=200 here to limit it.
register(
    "BDX_env",
    entry_point="footsteps_env:BDXEnv",
    autoreset=True,
)
|  | 
 | ||
|  | 
 | ||
def draw_clock(clock):
    """Display a clock dial with a single hand in an OpenCV window named "clock".

    The dial is a filled white disc of radius 100 px on a 200x200 black image.
    The hand is a red (BGR (0, 0, 255)), 2 px line from the centre to pixel
    ``(100 + 100 * clock[0], 100 + 100 * clock[1])`` (truncated to int).
    ``clock`` is presumably a unit 2-vector such as (cos, sin) of a gait
    phase; TODO confirm against the env's clock signal.
    """
    radius = 100
    centre = (radius, radius)

    canvas = np.zeros((2 * radius, 2 * radius, 3), np.uint8)
    canvas = cv2.circle(canvas, centre, radius, (255, 255, 255), -1)

    hand_tip = (
        int(radius + radius * clock[0]),
        int(radius + radius * clock[1]),
    )
    canvas = cv2.line(canvas, centre, hand_tip, (0, 0, 255), 2)

    cv2.imshow("clock", canvas)
    # A 1 ms wait lets HighGUI process events so the window actually refreshes.
    cv2.waitKey(1)
|  | 
 | ||
|  | 
 | ||
def draw_frame(pose, i, env):
    """Add a labelled red arrow marker for ``pose`` to the human-mode MuJoCo viewer.

    Args:
        pose: 4x4 homogeneous transform; ``[:3, 3]`` is used as the marker
            position and ``[:3, :3]`` as its orientation (after rotation).
        i: marker label; any value, converted with ``str``.
        env: object exposing ``mujoco_renderer`` (a Gymnasium MuJoCo env).
    """
    # Re-orient the frame by [0, 90, 0] about its own axes (FramesViewer
    # convention, presumably degrees) before drawing. This likely makes the
    # arrow geom point along the frame's x axis; TODO confirm.
    oriented = fv_utils.rotateInSelf(pose, [0, 90, 0])

    viewer = env.mujoco_renderer._get_viewer(render_mode="human")
    viewer.add_marker(
        pos=oriented[:3, 3],
        mat=oriented[:3, :3],
        size=[0.005, 0.005, 0.1],
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        rgba=[1, 0, 0, 1],
        label=str(i),
    )
|  | 
 | ||
|  | 
 | ||
|  | def test(env, sb3_algo, path_to_model): | ||
|  |     if not path_to_model.endswith(".zip"): | ||
|  |         models_paths = glob(path_to_model + "/*.zip") | ||
|  |         latest_model_id = 0 | ||
|  |         latest_model_path = None | ||
|  |         for model_path in models_paths: | ||
|  |             model_id = model_path.split("/")[-1][: -len(".zip")].split("_")[-1] | ||
|  |             if int(model_id) > latest_model_id: | ||
|  |                 latest_model_id = int(model_id) | ||
|  |                 latest_model_path = model_path | ||
|  | 
 | ||
|  |         if latest_model_path is None: | ||
|  |             print("No models found in directory: ", path_to_model) | ||
|  |             return | ||
|  | 
 | ||
|  |         print("Using latest model: ", latest_model_path) | ||
|  | 
 | ||
|  |         path_to_model = latest_model_path | ||
|  | 
 | ||
|  |     match sb3_algo: | ||
|  |         case "SAC": | ||
|  |             model = SAC.load(path_to_model, env=env) | ||
|  |         case "TD3": | ||
|  |             model = TD3.load(path_to_model, env=env) | ||
|  |         case "A2C": | ||
|  |             model = A2C.load(path_to_model, env=env) | ||
|  |         case "TQC": | ||
|  |             model = TQC.load(path_to_model, env=env) | ||
|  |         case "PPO": | ||
|  |             model = PPO.load(path_to_model, env=env) | ||
|  |         case _: | ||
|  |             print("Algorithm not found") | ||
|  |             return | ||
|  | 
 | ||
|  |     obs = env.reset()[0] | ||
|  |     done = False | ||
|  |     extra_steps = 500 | ||
|  |     while True: | ||
|  |         action, _ = model.predict(obs) | ||
|  |         obs, _, done, _, _ = env.step(action) | ||
|  |         footsteps = env.next_footsteps | ||
|  |         base_target_2D = np.mean( | ||
|  |             [footsteps[2][:3, 3][:2], footsteps[3][:3, 3][:2]], axis=0 | ||
|  |         ) | ||
|  |         base_target_frame = np.eye(4) | ||
|  |         base_target_frame[:3, 3][:2] = base_target_2D | ||
|  |         draw_frame(base_target_frame, "base target", env) | ||
|  |         base_pos_2D = env.data.body("base").xpos[:2] | ||
|  |         base_pos_frame = np.eye(4) | ||
|  |         base_pos_frame[:3, 3][:2] = base_pos_2D | ||
|  |         draw_frame(base_pos_frame, "base pos", env) | ||
|  | 
 | ||
|  |         # draw_clock(env.get_clock_signal()) | ||
|  | 
 | ||
|  |         for i, footstep in enumerate(footsteps[2:]): | ||
|  |             draw_frame(footstep, i, env) | ||
|  | 
 | ||
|  |         if done: | ||
|  |             extra_steps -= 1 | ||
|  | 
 | ||
|  |             if extra_steps < 0: | ||
|  |                 break | ||
|  | 
 | ||
|  | 
 | ||
if __name__ == "__main__":
    # Command-line entry point: parse the model path / algorithm, build the
    # BDX env with on-screen rendering, and run the rollout viewer.
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        required=True,  # test() cannot run without a model path
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        choices=["SAC", "TD3", "A2C", "TQC", "PPO"],
        help="Algorithm the model was trained with.",
    )
    args = parser.parse_args()

    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the renderer/viewer even if the rollout raises or is interrupted.
        gymenv.close()