65 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			65 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import pickle
 | |
| import pprint
 | |
| 
 | |
| import numpy as np
 | |
| from gymnasium.envs.registration import register
 | |
| from imitation.algorithms import density as db
 | |
| from imitation.data import serialize
 | |
| from imitation.util import util
 | |
| from stable_baselines3 import PPO
 | |
| from stable_baselines3.common.policies import ActorCriticPolicy
 | |
| 
 | |
# --- CLI: path to the pickled expert-demonstration dataset ------------------
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, required=True)
args = parser.parse_args()

# One seeded RNG shared by env construction and the density algorithm so the
# whole run is reproducible.
rng = np.random.default_rng(0)

# Register the custom BDX environment and build a 2-worker vectorized env.
register(id="BDX_env", entry_point="env:BDXEnv")
env = util.make_vec_env("BDX_env", rng=rng, n_envs=2)

# Load expert demonstrations. Use a context manager so the file handle is
# closed promptly (the original left it open).
# SECURITY NOTE(review): pickle.load can execute arbitrary code from the
# file — only pass datasets from trusted sources.
with open(args.dataset, "rb") as dataset_file:
    dataset = pickle.load(dataset_file)

# PPO learner that will be optimized against the density-model reward.
imitation_trainer = PPO(
    ActorCriticPolicy, env, learning_rate=3e-4, gamma=0.95, ent_coef=1e-4, n_steps=2048
)

# Kernel-density imitation: fit a Gaussian KDE over expert (state, action)
# pairs; its log-density then serves as the reward signal for PPO.
density_trainer = db.DensityAlgorithm(
    venv=env,
    rng=rng,
    demonstrations=dataset,
    rl_algo=imitation_trainer,
    density_type=db.DensityType.STATE_ACTION_DENSITY,
    is_stationary=True,
    kernel="gaussian",
    kernel_bandwidth=0.4,
    standardise_inputs=True,
    allow_variable_horizon=True,
)
density_trainer.train()  # Fit the density model to the demonstrations.
| 
 | |
| 
 | |
| def print_stats(density_trainer, n_trajectories):
 | |
|     stats = density_trainer.test_policy(n_trajectories=n_trajectories)
 | |
|     print("True reward function stats:")
 | |
|     pprint.pprint(stats)
 | |
|     stats_im = density_trainer.test_policy(
 | |
|         true_reward=False, n_trajectories=n_trajectories
 | |
|     )
 | |
|     print("Imitation reward function stats:")
 | |
|     pprint.pprint(stats_im)
 | |
| 
 | |
| 
 | |
# Baseline: evaluate the untrained policy under both reward functions.
print("Stats before training:")
print_stats(density_trainer, 1)

# Train for 1_000_000 steps to approach expert performance.
TRAIN_TIMESTEPS = 1_000_000
density_trainer.train_policy(TRAIN_TIMESTEPS, progress_bar=True)

# Re-evaluate after training to see the improvement.
print("Stats after training:")
print_stats(density_trainer, 1)

# Persist the learned PPO policy for later reuse.
density_trainer.policy.save("density_policy_ppo.zip")
 |