# Python source file: 42 lines, 1.1 KiB
import argparse
import pickle

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register
from imitation.algorithms import bc
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# Behavioural cloning reference:
# https://imitation.readthedocs.io/en/latest/algorithms/bc.html


def load_demonstrations(path):
    """Return the object unpickled from the file at *path*.

    The file is opened in binary mode and is closed before this function
    returns, even if unpickling raises.

    NOTE(review): ``pickle`` can execute arbitrary code while loading; only
    pass dataset files that come from a trusted source.

    Args:
        path: Filesystem path of the pickled dataset.

    Returns:
        Whatever object was stored in the pickle file.
    """
    with open(path, "rb") as dataset_file:
        return pickle.load(dataset_file)


def main(argv=None):
    """Train a behavioural-cloning policy on a pickled dataset and save it.

    Parses ``-d/--dataset``, loads the demonstrations, registers and creates
    the ``BDX_env`` environment, trains ``bc.BC`` for 10 epochs and writes
    the resulting policy to ``bc_policy_ppo.zip`` in the working directory.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the process command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, required=True)
    args = parser.parse_args(argv)

    dataset = load_demonstrations(args.dataset)

    register(id="BDX_env", entry_point="env_humanoid:BDXEnv")
    env = gym.make("BDX_env", render_mode=None)
    try:
        # Fixed seed so the random generator handed to BC is reproducible.
        rng = np.random.default_rng(0)

        # A PPO model is built only to borrow its policy network.
        ppo_policy = PPO(
            "MlpPolicy", env, policy_kwargs=dict(net_arch=[400, 300])
        ).policy  # not working with SAC for some reason

        bc_trainer = bc.BC(
            observation_space=env.observation_space,
            action_space=env.action_space,
            demonstrations=dataset,
            rng=rng,
            device="cpu",
            policy=ppo_policy,
        )
        bc_trainer.train(n_epochs=10)

        bc_trainer.policy.save("bc_policy_ppo.zip")

        # Optional sanity check of the trained policy (disabled):
        # reward, _ = evaluate_policy(bc_trainer.policy, env, 1)
        # print(reward)
    finally:
        # Release the environment's resources even if training fails.
        env.close()


if __name__ == "__main__":
    main()