152 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			152 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import os
 | |
| from datetime import datetime
 | |
| 
 | |
| import gymnasium as gym
 | |
| from gymnasium.envs.registration import register
 | |
| from sb3_contrib import TQC
 | |
| from stable_baselines3 import A2C, PPO, SAC, TD3
 | |
| 
 | |
| 
 | |
def train(env, sb3_algo, model_dir, log_dir, pretrained=None, device="cuda"):
    """Train a stable-baselines3 agent on *env*, checkpointing indefinitely.

    Runs ``model.learn`` in an endless loop, saving a checkpoint to
    ``{model_dir}/{sb3_algo}_{total_steps}`` every 10 000 timesteps.
    This function never returns on success (loop is infinite).

    Args:
        env: gymnasium environment to train on.
        sb3_algo: algorithm name — one of "SAC", "TD3", "A2C", "TQC", "PPO".
        model_dir: directory where checkpoints are written.
        log_dir: tensorboard log directory.
        pretrained: optional path to a saved model to resume from.
        device: torch device string (e.g. "cuda" or "cpu").
    """
    # SAC parameters found here https://github.com/hill-a/stable-baselines/issues/840#issuecomment-623171534
    algos = {"SAC": SAC, "TD3": TD3, "A2C": A2C, "TQC": TQC, "PPO": PPO}
    algo_cls = algos.get(sb3_algo)
    if algo_cls is None:
        print("Algorithm not found")
        return

    if pretrained is None:
        model = algo_cls(
            "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir
        )
    elif sb3_algo == "PPO":
        # Special case kept from the original: PPO resumes by building a fresh
        # model and loading only the policy weights (a direct PPO.load call was
        # previously commented out — presumably it caused issues; confirm before
        # changing this path).
        model = PPO(
            "MlpPolicy", env, verbose=1, device=device, tensorboard_log=log_dir
        )
        model.policy.load(pretrained)
    else:
        # BUG FIX: the resume path previously hard-coded device="cuda",
        # ignoring the caller-supplied *device* argument.
        model = algo_cls.load(
            pretrained,
            env=env,
            verbose=1,
            device=device,
            tensorboard_log=log_dir,
        )

    TIMESTEPS = 10000
    iters = 0
    while True:
        iters += 1

        # reset_num_timesteps=False keeps the step counter (and tensorboard
        # x-axis) monotonically increasing across learn() calls.
        model.learn(
            total_timesteps=TIMESTEPS,
            reset_num_timesteps=False,
            progress_bar=True,
        )
        model.save(f"{model_dir}/{sb3_algo}_{TIMESTEPS * iters}")
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line interface for launching a BDX training run.
    parser = argparse.ArgumentParser(description="Train BDX")
    parser.add_argument(
        "-a",
        "--algo",
        type=str,
        choices=["SAC", "TD3", "A2C", "TQC", "PPO"],
        default="SAC",
    )
    parser.add_argument("-p", "--pretrained", type=str, required=False)
    parser.add_argument("-d", "--device", type=str, required=False, default="cuda")
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        required=False,
        default=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        help="Name of the experiment",
    )
    args = parser.parse_args()

    # Register the custom BDX environment with gymnasium before instantiating it.
    # (An alternative entry point "env:BDXEnv" existed previously.)
    register(
        id="BDX_env",
        entry_point="simple_env:BDXEnv",
        max_episode_steps=None,  # formerly 500
        autoreset=True,
    )
    training_env = gym.make("BDX_env", render_mode=None)

    # One directory per experiment: checkpoints at the top level,
    # tensorboard logs under logs/.
    checkpoint_dir = args.name
    tb_log_dir = "logs/" + args.name
    for directory in (checkpoint_dir, tb_log_dir):
        os.makedirs(directory, exist_ok=True)

    train(
        training_env,
        args.algo,
        pretrained=args.pretrained,
        model_dir=checkpoint_dir,
        log_dir=tb_log_dir,
        device=args.device,
    )
 |