84 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			84 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import argparse | ||
|  | import os | ||
|  | from glob import glob | ||
|  | 
 | ||
|  | import gymnasium as gym | ||
|  | from gymnasium.envs.registration import register | ||
|  | from sb3_contrib import TQC | ||
|  | from stable_baselines3 import A2C, PPO, SAC, TD3 | ||
|  | 
 | ||
# Make the BDX environment constructible through ``gym.make("BDX_env")``.
# The entry point string names the ``BDXEnv`` class in the ``env_humanoid``
# module. No episode step limit is configured here.
register(
    id="BDX_env",
    entry_point="env_humanoid:BDXEnv",
    autoreset=True,
)
|  | 
 | ||
|  | 
 | ||
|  | def test(env, sb3_algo, path_to_model): | ||
|  | 
 | ||
|  |     if not path_to_model.endswith(".zip"): | ||
|  |         models_paths = glob(path_to_model + "/*.zip") | ||
|  |         latest_model_id = 0 | ||
|  |         latest_model_path = None | ||
|  |         for model_path in models_paths: | ||
|  |             model_id = model_path.split("/")[-1][: -len(".zip")].split("_")[-1] | ||
|  |             if int(model_id) > latest_model_id: | ||
|  |                 latest_model_id = int(model_id) | ||
|  |                 latest_model_path = model_path | ||
|  | 
 | ||
|  |         if latest_model_path is None: | ||
|  |             print("No models found in directory: ", path_to_model) | ||
|  |             return | ||
|  | 
 | ||
|  |         print("Using latest model: ", latest_model_path) | ||
|  | 
 | ||
|  |         path_to_model = latest_model_path | ||
|  | 
 | ||
|  |     match sb3_algo: | ||
|  |         case "SAC": | ||
|  |             model = SAC.load(path_to_model, env=env) | ||
|  |         case "TD3": | ||
|  |             model = TD3.load(path_to_model, env=env) | ||
|  |         case "A2C": | ||
|  |             model = A2C.load(path_to_model, env=env) | ||
|  |         case "TQC": | ||
|  |             model = TQC.load(path_to_model, env=env) | ||
|  |         case "PPO": | ||
|  |             model = PPO("MlpPolicy", env) | ||
|  |             model.policy.load(path_to_model) | ||
|  |             # model = PPO.load(path_to_model, env=env) | ||
|  | 
 | ||
|  |         case _: | ||
|  |             print("Algorithm not found") | ||
|  |             return | ||
|  | 
 | ||
|  |     obs = env.reset()[0] | ||
|  |     done = False | ||
|  |     extra_steps = 500 | ||
|  |     while True: | ||
|  |         action, _ = model.predict(obs) | ||
|  |         obs, _, done, _, _ = env.step(action) | ||
|  | 
 | ||
|  |         if done: | ||
|  |             extra_steps -= 1 | ||
|  | 
 | ||
|  |             if extra_steps < 0: | ||
|  |                 break | ||
|  | 
 | ||
|  | 
 | ||
if __name__ == "__main__":
    # Command-line entry point: parse the model location and algorithm name,
    # build the rendered environment and run the model in it.
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # Without a path there is nothing to load; fail at argument parsing
        # instead of passing None into test().
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        help="Algorithm the model was trained with: SAC, TD3, A2C, TQC or PPO.",
    )
    args = parser.parse_args()

    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the environment (and its render window) even when the
        # rollout is interrupted or raises.
        gymenv.close()