# File size: 672 Bytes
# Ref: 3b0a5c8
import gymnasium as gym
from stable_baselines3 import PPO
def main(total_timesteps: int = 10_000, n_eval_episodes: int = 5) -> None:
    """Train a PPO agent on CartPole-v1, save it, then run a short evaluation.

    Args:
        total_timesteps: Number of environment steps passed to ``model.learn``.
        n_eval_episodes: Number of episodes to roll out with the trained
            policy after training. The evaluation is bounded so that the
            environment is always closed afterwards.
    """
    # Create the environment (CartPole)
    env = gym.make("CartPole-v1")
    try:
        # Initialize PPO agent
        model = PPO("MlpPolicy", env, verbose=1)
        # Train the model for a few timesteps
        model.learn(total_timesteps=total_timesteps)
        # Save the trained model
        model.save("ppo_cartpole")

        # Test the trained model for a fixed number of episodes.
        for episode in range(1, n_eval_episodes + 1):
            obs, _info = env.reset()
            episode_return = 0.0
            terminated = truncated = False
            # Gymnasium's step() reports episode end through two flags:
            # `terminated` and `truncated`. Checking only `terminated` would
            # keep stepping after a truncated episode instead of resetting.
            while not (terminated or truncated):
                # deterministic=True picks the policy's best action rather than
                # sampling one, which is the usual setting for evaluation.
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, _info = env.step(action)
                episode_return += float(reward)
            print(f"Episode {episode}: return = {episode_return:.1f}")
    finally:
        # Release the environment even if training or evaluation raises.
        env.close()


if __name__ == "__main__":
    main()