152 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			152 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import argparse | ||
|  | import os | ||
|  | from datetime import datetime | ||
|  | 
 | ||
|  | import gymnasium as gym | ||
|  | from gymnasium.envs.registration import register | ||
|  | from sb3_contrib import TQC | ||
|  | from stable_baselines3 import A2C, PPO, SAC, TD3 | ||
|  | 
 | ||
|  | 
 | ||
|  | def train(env, sb3_algo, model_dir, log_dir, pretrained=None, device="cuda"): | ||
|  |     # SAC parameters found here https://github.com/hill-a/stable-baselines/issues/840#issuecomment-623171534 | ||
|  |     if pretrained is None: | ||
|  |         match sb3_algo: | ||
|  |             case "SAC": | ||
|  |                 model = SAC( | ||
|  |                     "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir | ||
|  |                 ) | ||
|  |             case "TD3": | ||
|  |                 model = TD3( | ||
|  |                     "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir | ||
|  |                 ) | ||
|  |             case "A2C": | ||
|  |                 model = A2C( | ||
|  |                     "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir | ||
|  |                 ) | ||
|  |             case "TQC": | ||
|  |                 model = TQC( | ||
|  |                     "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir | ||
|  |                 ) | ||
|  |             case "PPO": | ||
|  |                 model = PPO( | ||
|  |                     "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir | ||
|  |                 ) | ||
|  |             case _: | ||
|  |                 print("Algorithm not found") | ||
|  |                 return | ||
|  |     else: | ||
|  |         match sb3_algo: | ||
|  |             case "SAC": | ||
|  |                 model = SAC.load( | ||
|  |                     pretrained, | ||
|  |                     env=env, | ||
|  |                     verbose=1, | ||
|  |                     device="cuda", | ||
|  |                     tensorboard_log=log_dir, | ||
|  |                 ) | ||
|  |             case "TD3": | ||
|  |                 model = TD3.load( | ||
|  |                     pretrained, | ||
|  |                     env=env, | ||
|  |                     verbose=1, | ||
|  |                     device="cuda", | ||
|  |                     tensorboard_log=log_dir, | ||
|  |                 ) | ||
|  |             case "A2C": | ||
|  |                 model = A2C.load( | ||
|  |                     pretrained, | ||
|  |                     env=env, | ||
|  |                     verbose=1, | ||
|  |                     device="cuda", | ||
|  |                     tensorboard_log=log_dir, | ||
|  |                 ) | ||
|  |             case "TQC": | ||
|  |                 model = TQC.load( | ||
|  |                     pretrained, | ||
|  |                     env=env, | ||
|  |                     verbose=1, | ||
|  |                     device="cuda", | ||
|  |                     tensorboard_log=log_dir, | ||
|  |                 ) | ||
|  |             case "PPO": | ||
|  |                 model = PPO( | ||
|  |                     "MlpPolicy", env, verbose=1, device="cuda", tensorboard_log=log_dir | ||
|  |                 ) | ||
|  |                 model.policy.load(pretrained) | ||
|  |                 # model = PPO.load( | ||
|  |                 #     pretrained, | ||
|  |                 #     env=env, | ||
|  |                 #     verbose=1, | ||
|  |                 #     device="cuda", | ||
|  |                 #     tensorboard_log=log_dir, | ||
|  |                 # ) | ||
|  |             case _: | ||
|  |                 print("Algorithm not found") | ||
|  |                 return | ||
|  | 
 | ||
|  |     TIMESTEPS = 10000 | ||
|  |     iters = 0 | ||
|  |     while True: | ||
|  |         iters += 1 | ||
|  | 
 | ||
|  |         model.learn( | ||
|  |             total_timesteps=TIMESTEPS, | ||
|  |             reset_num_timesteps=False, | ||
|  |             progress_bar=True, | ||
|  |         ) | ||
|  |         model.save(f"{model_dir}/{sb3_algo}_{TIMESTEPS*iters}") | ||
|  | 
 | ||
|  | 
 | ||
if __name__ == "__main__":
    # Command-line options: algorithm, optional checkpoint to resume from,
    # compute device and experiment name (defaults to a start-time stamp).
    cli = argparse.ArgumentParser(description="Train BDX")
    cli.add_argument(
        "-a", "--algo", type=str, choices=["SAC", "TD3", "A2C", "TQC", "PPO"], default="SAC"
    )
    cli.add_argument("-p", "--pretrained", type=str, required=False)
    cli.add_argument("-d", "--device", type=str, required=False, default="cuda")
    cli.add_argument(
        "-n",
        "--name",
        type=str,
        required=False,
        default=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        help="Name of the experiment",
    )
    opts = cli.parse_args()

    # Make the environment available to gym.make under the id "BDX_env".
    # (An alternative entry point, "env:BDXEnv", was previously registered
    # under the same id.)
    env_spec = {
        "id": "BDX_env",
        "entry_point": "simple_env:BDXEnv",
        "max_episode_steps": None,  # formerly 500
        "autoreset": True,
    }
    register(**env_spec)

    env = gym.make("BDX_env", render_mode=None)

    # Checkpoints go to ./<name>, TensorBoard logs to ./logs/<name>.
    model_dir = opts.name
    log_dir = f"logs/{opts.name}"
    for directory in (model_dir, log_dir):
        os.makedirs(directory, exist_ok=True)

    train(
        env,
        opts.algo,
        model_dir,
        log_dir,
        pretrained=opts.pretrained,
        device=opts.device,
    )