import argparse
import pickle

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

# Check this out https://imitation.readthedocs.io/en/latest/algorithms/bc.html

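# The expert demonstrations are supplied as a pickle file via --dataset.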
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, required=True)
args = parser.parse_args()

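# Load the pickled expert demonstrations that GAIL will imitate.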
with open(args.dataset, "rb") as f:
    dataset = pickle.load(f)

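# Register the custom environment (env:BDXEnv) so it can be built by id.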
register(id="BDX_env", entry_point="env:BDXEnv")

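# Seed and build a vectorized environment with 8 parallel copies;
# RolloutInfoWrapper records episode data so full rollouts can be reconstructed.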
SEED = 42
rng = np.random.default_rng(SEED)
# env = gym.make("BDX_env", render_mode=None)
env = make_vec_env(
    "BDX_env",
    rng=rng,
    n_envs=8,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # to compute rollouts
)

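# Generator: a PPO policy trained against the reward learned by the discriminator.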
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
    tensorboard_log="logs",
)
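# Reward network used by the GAIL discriminator; RunningNorm normalizes its inputs.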
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
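# Adversarial trainer: alternates discriminator updates with PPO generator updates.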
gail_trainer = GAIL(
    demonstrations=dataset,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
)

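# Baseline: evaluate the untrained policy over 100 episodes.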
|  | print("evaluate the learner before training") | ||
|  | env.seed(SEED) | ||
|  | learner_rewards_before_training, _ = evaluate_policy( | ||
|  |     learner, | ||
|  |     env, | ||
|  |     100, | ||
|  |     return_episode_rewards=True, | ||
|  | ) | ||
|  | 
 | ||
|  | print("train the learner and evaluate again") | ||
|  | gail_trainer.train(500000)  # Train for 800_000 steps to match expert. | ||
|  | 
 | ||
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner,
    env,
    100,
    return_episode_rewards=True,
)

|  | print("mean episode reward before training:", np.mean(learner_rewards_before_training)) | ||
|  | print("mean episode reward after training:", np.mean(learner_rewards_after_training)) | ||
|  | 
 | ||
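# gail_trainer.policy is the generator's (PPO) policy; save it for later use.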
|  | print("Save the policy") | ||
|  | gail_trainer.policy.save("gail_policy_ppo.zip") |