import argparse
import pickle

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

# Check this out https://imitation.readthedocs.io/en/latest/algorithms/bc.html

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, required=True)
args = parser.parse_args()

# Load the pickled expert demonstrations that are fed to GAIL below.
with open(args.dataset, "rb") as f:
    dataset = pickle.load(f)

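# Note on the dataset format (an assumption, not checked by this script): GAIL's
# `demonstrations` argument expects expert data in a format imitation understands,
# e.g. a sequence of imitation.data.types.TrajectoryWithRew or a Transitions
# object, so that is what the pickle file should contain.
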
# Register the custom BDX environment (class BDXEnv in env.py).
register(id="BDX_env", entry_point="env:BDXEnv")

SEED = 42
rng = np.random.default_rng(SEED)

# env = gym.make("BDX_env", render_mode=None)
env = make_vec_env(
    "BDX_env",
    rng=rng,
    n_envs=8,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # to compute rollouts
)

# Generator: the PPO policy that GAIL trains to imitate the expert.
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
    tensorboard_log="logs",
)

# Discriminator/reward network, with running normalization of its inputs.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

# Adversarial trainer: alternates discriminator updates with PPO updates of the
# generator against the learned reward.
gail_trainer = GAIL(
    demonstrations=dataset,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
)

| print("evaluate the learner before training")
 | |
| env.seed(SEED)
 | |
| learner_rewards_before_training, _ = evaluate_policy(
 | |
|     learner,
 | |
|     env,
 | |
|     100,
 | |
|     return_episode_rewards=True,
 | |
| )
 | |
| 
 | |
| print("train the learner and evaluate again")
 | |
| gail_trainer.train(500000)  # Train for 800_000 steps to match expert.
 | |
| 
 | |
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner,
    env,
    100,  # n_eval_episodes
    return_episode_rewards=True,
)

print("mean episode reward before training:", np.mean(learner_rewards_before_training))
print("mean episode reward after training:", np.mean(learner_rewards_after_training))

# Save the trained generator policy (the underlying stable_baselines3 MlpPolicy).
print("Save the policy")
gail_trainer.policy.save("gail_policy_ppo.zip")
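
# A minimal reload sketch (an assumption, not part of this training script): the
# saved file holds a stable_baselines3 policy, so it can be loaded back for
# inference roughly like this:
#
#     from stable_baselines3.ppo import MlpPolicy
#     policy = MlpPolicy.load("gail_policy_ppo.zip")
#     obs = env.reset()
#     action, _ = policy.predict(obs, deterministic=True)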