119 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			119 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import time
 | |
| from glob import glob
 | |
| 
 | |
| import gymnasium as gym
 | |
| import mujoco
 | |
| import mujoco.viewer
 | |
| import numpy as np
 | |
| from gymnasium.envs.registration import register
 | |
| from stable_baselines3 import PPO, SAC
 | |
| 
 | |
| from mini_bdx.utils.mujoco_utils import check_contact
 | |
| 
 | |
| 
 | |
def get_observation(data, left_contact, right_contact):
    """Build the flat observation vector fed to the policy.

    The result is a 1-D array made of, in order: a flattened copy of
    ``data.qpos``, a flattened copy of ``data.qvel`` and the two foot
    contact flags (left first, then right).

    Args:
        data: Object exposing ``qpos`` and ``qvel`` arrays.
        left_contact: Contact flag of the left foot.
        right_contact: Contact flag of the right foot.

    Returns:
        The concatenated observation as a NumPy array.
    """
    # qpos: position/rotation of trunk + position of all joints
    joint_positions = data.qpos.flat.copy()
    # qvel: positional/angular velocity of trunk + of all joints
    joint_velocities = data.qvel.flat.copy()
    contact_flags = [left_contact, right_contact]

    return np.concatenate((joint_positions, joint_velocities, contact_flags))
 | |
| 
 | |
| 
 | |
def key_callback(keycode):
    """Keyboard hook handed to the passive viewer; ignores every key press."""
    return None
 | |
| 
 | |
| 
 | |
def get_model_from_dir(path_to_model):
    """Resolve ``path_to_model`` to a single ``.zip`` model file.

    If ``path_to_model`` already ends with ``.zip`` it is returned unchanged.
    Otherwise it is treated as a directory and the ``.zip`` file whose name
    ends with the highest integer suffix (``<anything>_<id>.zip``) is chosen.

    Args:
        path_to_model: Path to a ``.zip`` file or to a directory of them.

    Returns:
        The selected model path, or ``None`` (after printing a message) when
        the directory contains no ``.zip`` file with a numeric suffix.
    """
    import os  # local import: the module header does not import os

    if path_to_model.endswith(".zip"):
        return path_to_model

    # Start from None rather than 0 so a checkpoint numbered 0 is selectable.
    latest_model_id = None
    latest_model_path = None
    for model_path in glob(os.path.join(path_to_model, "*.zip")):
        stem = os.path.basename(model_path)[: -len(".zip")]
        try:
            model_id = int(stem.split("_")[-1])
        except ValueError:
            # Not a numbered checkpoint (e.g. "best_model.zip"); skip it
            # instead of aborting the whole search.
            continue
        if latest_model_id is None or model_id > latest_model_id:
            latest_model_id = model_id
            latest_model_path = model_path

    if latest_model_path is None:
        print("No models found in directory: ", path_to_model)
        return None

    return latest_model_path
 | |
| 
 | |
| 
 | |
def get_feet_contact(data, model):
    """Query floor contact for both feet.

    ``"foot_module"`` is checked for the right foot and ``"foot_module_2"``
    for the left foot, each against ``"floor"``.

    Returns:
        A ``(right_contact, left_contact)`` tuple; the right foot comes first.
    """
    right, left = (
        check_contact(data, model, foot_name, "floor")
        for foot_name in ("foot_module", "foot_module_2")
    )
    return right, left
 | |
| 
 | |
| 
 | |
def play(env, path_to_model):
    """Run a saved PPO policy on the BDX scene in a passive MuJoCo viewer.

    Each iteration reads the feet contacts, builds the observation, asks the
    policy for an action, writes it to ``data.ctrl``, steps the simulation
    and refreshes the viewer. The loop ends when the viewer window is closed
    or on Ctrl-C.

    Args:
        env: Gymnasium environment used to construct the PPO model.
        path_to_model: A ``.zip`` file, or a directory from which the latest
            model is picked by ``get_model_from_dir``.

    Raises:
        FileNotFoundError: If no model file can be resolved from
            ``path_to_model``.
    """
    model_path = get_model_from_dir(path_to_model)
    if model_path is None:
        # Fail with a clear message instead of crashing inside the loader.
        raise FileNotFoundError(
            f"No model file could be resolved from: {path_to_model}"
        )

    model = mujoco.MjModel.from_xml_path("../../mini_bdx/robots/bdx/scene.xml")
    data = mujoco.MjData(model)

    nn_model = PPO("MlpPolicy", env)
    # policy.load() is a classmethod that RETURNS a new policy; it does not
    # modify the instance it is called on. The result must be assigned,
    # otherwise the freshly initialised (untrained) policy keeps being used.
    nn_model.policy = nn_model.policy.load(model_path)

    # Open the viewer only after loading succeeded so a load failure does not
    # leave a window behind.
    viewer = mujoco.viewer.launch_passive(model, data, key_callback=key_callback)
    try:
        # Stop stepping once the user closes the viewer window.
        while viewer.is_running():
            right_contact, left_contact = get_feet_contact(data, model)
            obs = get_observation(data, left_contact, right_contact)
            action, _ = nn_model.predict(obs)
            data.ctrl[:] = action

            mujoco.mj_step(model, data)
            viewer.sync()
            time.sleep(model.opt.timestep / 2.5)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop playback.
        pass
    finally:
        # Close exactly once, whatever ended the loop.
        viewer.close()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: register the BDX environment, build it and
    # play back the model given with -p/--path.
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # Without a path, None would reach play() and fail with an opaque
        # AttributeError; let argparse report the missing option instead.
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    # NOTE: --algo is accepted for compatibility but is not forwarded to
    # play(), which takes no algorithm argument.
    parser.add_argument("-a", "--algo", default="SAC")
    args = parser.parse_args()

    register(id="BDX_env", entry_point="env_humanoid:BDXEnv")
    env = gym.make("BDX_env", render_mode=None)
    try:
        play(env, path_to_model=args.path)
    finally:
        # Release environment resources even if playback fails.
        env.close()
 |