import gymnasium as gym
from stable_baselines3 import PPO

# Number of complete episodes to play back after training. The playback
# loop is bounded so the script terminates and the environment is closed.
N_EVAL_EPISODES = 10

# Create the environment (CartPole)
env = gym.make("CartPole-v1")

try:
    # Initialize PPO agent
    model = PPO("MlpPolicy", env, verbose=1)

    # Train the model for a few timesteps
    model.learn(total_timesteps=10000)

    # Save the trained model
    model.save("ppo_cartpole")

    # Test the trained model
    for episode in range(N_EVAL_EPISODES):
        # reset() returns both the initial observation and an info dict
        obs, info = env.reset()
        episode_return = 0.0
        done = truncated = False

        # An episode ends either on termination (`done`) or on truncation
        # (`truncated`); stepping past either without a reset is invalid,
        # so both flags must end the inner loop.
        while not (done or truncated):
            action, _states = model.predict(obs)
            obs, reward, done, truncated, info = env.step(action)
            episode_return += float(reward)

        print(f"Episode {episode + 1}: return={episode_return}")
finally:
    # Release environment resources even if training or playback fails.
    env.close()