131 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			131 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import os
 | |
| from glob import glob
 | |
| 
 | |
| import cv2
 | |
| import FramesViewer.utils as fv_utils
 | |
| import gymnasium as gym
 | |
| import mujoco
 | |
| import numpy as np
 | |
| from gymnasium.envs.registration import register
 | |
| from sb3_contrib import TQC
 | |
| from stable_baselines3 import A2C, PPO, SAC, TD3
 | |
| 
 | |
# Expose the custom footsteps environment under the id used by gym.make("BDX_env").
# autoreset=True: the environment is reset automatically when an episode ends.
# max_episode_steps (formerly sketched as 200) is intentionally not passed here.
register("BDX_env", entry_point="footsteps_env:BDXEnv", autoreset=True)
 | |
| 
 | |
| 
 | |
def draw_clock(clock):
    """Show a small OpenCV debug window visualising a two-component clock.

    A filled white disc is drawn on a black square canvas, then a line
    (colour (0, 0, 255), i.e. red in OpenCV's BGR order) is drawn from the
    centre to ``(r + r * clock[0], r + r * clock[1])`` in pixel coordinates,
    where ``r`` is the disc radius.

    Args:
        clock: indexable with at least two numeric entries; presumably each
            component lies in [-1, 1] so the line stays on the disc — TODO confirm.
    """
    radius = 100
    centre = (radius, radius)

    canvas = np.zeros((2 * radius, 2 * radius, 3), np.uint8)
    canvas = cv2.circle(canvas, centre, radius, (255, 255, 255), -1)

    tip = (
        int(radius + radius * clock[0]),
        int(radius + radius * clock[1]),
    )
    canvas = cv2.line(canvas, centre, tip, (0, 0, 255), 2)

    cv2.imshow("clock", canvas)
    # A 1 ms wait lets the HighGUI window refresh without blocking the caller.
    cv2.waitKey(1)
 | |
| 
 | |
| 
 | |
def draw_frame(pose, i, env):
    """Add a red arrow marker for ``pose`` to the MuJoCo human-mode viewer.

    The pose is first rotated in its own frame by ``[0, 90, 0]`` via
    ``fv_utils.rotateInSelf`` (presumably Euler angles in degrees — confirm).
    Its translation column and rotation block are then handed to the viewer's
    ``add_marker`` as an arrow labelled ``str(i)``.

    Args:
        pose: presumably a 4x4 homogeneous transform (read as ``[:3, 3]`` and
            ``[:3, :3]``).
        i: marker label; any value, converted with ``str``.
        env: object exposing ``mujoco_renderer`` (e.g. a gymnasium MuJoCo env).
    """
    rotated = fv_utils.rotateInSelf(pose, [0, 90, 0])
    viewer = env.mujoco_renderer._get_viewer(render_mode="human")

    position = rotated[:3, 3]
    orientation = rotated[:3, :3]
    viewer.add_marker(
        pos=position,
        mat=orientation,
        size=[0.005, 0.005, 0.1],
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        rgba=[1, 0, 0, 1],
        label=str(i),
    )
 | |
| 
 | |
| 
 | |
def _find_latest_model(directory):
    """Return the ``*.zip`` in ``directory`` with the largest ``_<int>`` name suffix.

    Files whose stem does not end in ``_<int>`` (e.g. ``best_model.zip``) are
    skipped. Ties keep the first path returned by ``glob``. Returns None when no
    checkpoint qualifies.
    """
    latest_id = None
    latest_path = None
    for model_path in glob(os.path.join(directory, "*.zip")):
        stem = os.path.splitext(os.path.basename(model_path))[0]
        try:
            model_id = int(stem.rsplit("_", 1)[-1])
        except ValueError:
            # No numeric suffix: it cannot be ranked, so ignore it rather than crash.
            continue
        # A None sentinel (not 0) lets a checkpoint numbered 0 be selected too.
        if latest_id is None or model_id > latest_id:
            latest_id = model_id
            latest_path = model_path
    return latest_path


def test(env, sb3_algo, path_to_model):
    """Load a trained Stable-Baselines3 policy and roll it out with debug markers.

    Args:
        env: gymnasium env, possibly wrapped (e.g. by ``gym.make``). Its
            underlying env must expose ``next_footsteps``, ``data`` and
            ``mujoco_renderer``.
        sb3_algo: one of "SAC", "TD3", "A2C", "TQC", "PPO".
        path_to_model: a ``.zip`` checkpoint, or a directory. For a directory,
            the ``*.zip`` with the largest ``_<int>`` suffix is used.

    Returns None. Prints a message and returns early when no checkpoint is
    found or the algorithm name is unknown.
    """
    if not path_to_model.endswith(".zip"):
        latest_model_path = _find_latest_model(path_to_model)
        if latest_model_path is None:
            print("No models found in directory: ", path_to_model)
            return
        print("Using latest model: ", latest_model_path)
        path_to_model = latest_model_path

    algo_classes = {"SAC": SAC, "TD3": TD3, "A2C": A2C, "TQC": TQC, "PPO": PPO}
    algo_class = algo_classes.get(sb3_algo)
    if algo_class is None:
        print("Algorithm not found")
        return
    model = algo_class.load(path_to_model, env=env)

    # Custom attributes live on the innermost env. gymnasium deprecates
    # forwarding unknown attribute access through the wrappers added by
    # gym.make, so reach the base env explicitly.
    base_env = env.unwrapped

    obs, _ = env.reset()
    # NOTE(review): extra_steps is only decremented on steps that report
    # terminated=True (truncation is ignored). The loop therefore exits after
    # 501 such steps, not 500 steps after the first termination. Confirm intent.
    extra_steps = 500
    while True:
        action, _ = model.predict(obs)
        obs, _, done, _, _ = env.step(action)

        footsteps = base_env.next_footsteps

        # Base target: x/y midpoint of the translations of footsteps[2] and footsteps[3].
        base_target_frame = np.eye(4)
        base_target_frame[:2, 3] = np.mean(
            [footsteps[2][:2, 3], footsteps[3][:2, 3]], axis=0
        )
        draw_frame(base_target_frame, "base target", base_env)

        # Current x/y position of the "base" body.
        base_pos_frame = np.eye(4)
        base_pos_frame[:2, 3] = base_env.data.body("base").xpos[:2]
        draw_frame(base_pos_frame, "base pos", base_env)

        # Optional debug visualisation:
        # draw_clock(env.get_clock_signal())

        for i, footstep in enumerate(footsteps[2:]):
            draw_frame(footstep, i, base_env)

        if done:
            extra_steps -= 1
            if extra_steps < 0:
                break
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # test() calls .endswith() on this value, so a missing path would crash
        # with AttributeError; make argparse report it instead.
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        help="Algorithm the model was trained with: SAC, TD3, A2C, TQC or PPO.",
    )
    args = parser.parse_args()

    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the viewer/simulation even if the rollout is interrupted (e.g. Ctrl-C).
        gymenv.close()
 |