84 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import os
 | |
| from glob import glob
 | |
| 
 | |
| import gymnasium as gym
 | |
| from gymnasium.envs.registration import register
 | |
| from sb3_contrib import TQC
 | |
| from stable_baselines3 import A2C, PPO, SAC, TD3
 | |
| 
 | |
# Register the custom environment under the id later passed to `gym.make`.
# `max_episode_steps=200` is currently disabled, so no step limit is passed here.
register(id="BDX_env", entry_point="env_humanoid:BDXEnv", autoreset=True)
 | |
| 
 | |
| 
 | |
| def test(env, sb3_algo, path_to_model):
 | |
| 
 | |
|     if not path_to_model.endswith(".zip"):
 | |
|         models_paths = glob(path_to_model + "/*.zip")
 | |
|         latest_model_id = 0
 | |
|         latest_model_path = None
 | |
|         for model_path in models_paths:
 | |
|             model_id = model_path.split("/")[-1][: -len(".zip")].split("_")[-1]
 | |
|             if int(model_id) > latest_model_id:
 | |
|                 latest_model_id = int(model_id)
 | |
|                 latest_model_path = model_path
 | |
| 
 | |
|         if latest_model_path is None:
 | |
|             print("No models found in directory: ", path_to_model)
 | |
|             return
 | |
| 
 | |
|         print("Using latest model: ", latest_model_path)
 | |
| 
 | |
|         path_to_model = latest_model_path
 | |
| 
 | |
|     match sb3_algo:
 | |
|         case "SAC":
 | |
|             model = SAC.load(path_to_model, env=env)
 | |
|         case "TD3":
 | |
|             model = TD3.load(path_to_model, env=env)
 | |
|         case "A2C":
 | |
|             model = A2C.load(path_to_model, env=env)
 | |
|         case "TQC":
 | |
|             model = TQC.load(path_to_model, env=env)
 | |
|         case "PPO":
 | |
|             model = PPO("MlpPolicy", env)
 | |
|             model.policy.load(path_to_model)
 | |
|             # model = PPO.load(path_to_model, env=env)
 | |
| 
 | |
|         case _:
 | |
|             print("Algorithm not found")
 | |
|             return
 | |
| 
 | |
|     obs = env.reset()[0]
 | |
|     done = False
 | |
|     extra_steps = 500
 | |
|     while True:
 | |
|         action, _ = model.predict(obs)
 | |
|         obs, _, done, _, _ = env.step(action)
 | |
| 
 | |
|         if done:
 | |
|             extra_steps -= 1
 | |
| 
 | |
|             if extra_steps < 0:
 | |
|                 break
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: build the environment with on-screen
    # rendering and replay the requested model in it.
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # Without a path, `test` would fail on `None.endswith`; let argparse
        # report the missing option instead.
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        help="Algorithm the model was trained with: SAC, TD3, A2C, TQC or PPO.",
    )
    args = parser.parse_args()

    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the environment's resources even if loading or stepping fails.
        gymenv.close()
 |