ppo_cartpol / main.py
jasondos's picture
Upload folder using huggingface_hub
3b0a5c8 verified
import gymnasium as gym
from stable_baselines3 import PPO
# Train a PPO agent on CartPole, save it, then play a fixed number of
# evaluation episodes and report the return of each one.

# Number of environment steps used for training.
TRAIN_TIMESTEPS = 10000
# Number of episodes to play with the trained policy before exiting.
EVAL_EPISODES = 10

# Create the environment (CartPole)
env = gym.make("CartPole-v1")

try:
    # Initialize PPO agent
    model = PPO("MlpPolicy", env, verbose=1)

    # Train the model for a few timesteps
    model.learn(total_timesteps=TRAIN_TIMESTEPS)

    # Save the trained model
    model.save("ppo_cartpole")

    # Test the trained model
    for episode in range(1, EVAL_EPISODES + 1):
        # reset() returns both the initial observation and an info dict
        obs, info = env.reset()
        episode_return = 0.0
        terminated = truncated = False

        # An episode is over when the environment reports either flag:
        # `terminated` (a terminal state was reached) or `truncated`
        # (the episode was cut off, e.g. by a step limit). Checking only
        # one of them would keep stepping an environment that has
        # already finished.
        while not (terminated or truncated):
            action, _states = model.predict(obs)
            obs, reward, terminated, truncated, info = env.step(action)
            episode_return += float(reward)

        print(f"Episode {episode}: return = {episode_return}")
finally:
    # Always release the environment, even if training or evaluation
    # raises or is interrupted.
    env.close()