metadata
license: apache-2.0
datasets:
  - lerobot/pusht_keypoints
pipeline_tag: robotics

Model Card for Diffusion Policy / PushT (keypoints)

Diffusion Policy (as described in "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion") trained on the PushT environment from gym-pusht using keypoint-only observations.

How to Get Started with the Model

Use python lerobot/scripts/eval.py -p lerobot/diffusion_pusht to evaluate for 50 episodes with the outputs sent to outputs/eval.

For further information, please see the LeRobot library (particularly the evaluation script).

Training Details

Trained with TODO commit hash

The model was trained with LeRobot's training script on the pusht_keypoints dataset, using this command:

python lerobot/scripts/train.py \
  hydra.job.name=diffusion_pusht_keypoints \
  hydra.run.dir=outputs/train/2024-07-03/13-52-44_diffusion_pusht_keypoints \
  env=pusht_keypoints \
  policy=diffusion_pusht_keypoints \
  training.save_checkpoint=true \
  training.offline_steps=200000 \
  training.save_freq=20000 \
  training.eval_freq=10000 \
  training.log_freq=50 \
  training.num_workers=4 \
  eval.n_episodes=50 \
  eval.batch_size=50 \
  wandb.enable=true \
  wandb.disable_artifact=true \
  device=cuda \
  use_amp=true
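As a rough sanity check on the overrides above, the sketch below derives the checkpoint and evaluation schedule they imply. It assumes that checkpointing and evaluation fire whenever the step count is an exact multiple of training.save_freq and training.eval_freq respectively. That is an assumption about the training loop, not a description of LeRobot's code, and event_steps is a hypothetical helper.

```python
# Hypothetical helper: derive the checkpoint/evaluation schedule implied by the
# training overrides. Assumes an event fires whenever the 1-indexed step count
# is a multiple of its frequency (an assumption, not LeRobot's actual logic).


def event_steps(total_steps: int, freq: int) -> list[int]:
    """Return the steps at which a periodic event fires."""
    return list(range(freq, total_steps + 1, freq))


OFFLINE_STEPS = 200_000  # training.offline_steps
SAVE_FREQ = 20_000       # training.save_freq
EVAL_FREQ = 10_000       # training.eval_freq
EVAL_EPISODES = 50       # eval.n_episodes

checkpoints = event_steps(OFFLINE_STEPS, SAVE_FREQ)
evals = event_steps(OFFLINE_STEPS, EVAL_FREQ)

print(f"{len(checkpoints)} checkpoints, last at step {checkpoints[-1]}")
print(f"{len(evals)} evaluations, {len(evals) * EVAL_EPISODES} rollout episodes in total")
```

Under that assumption, the run produces 10 checkpoints (the last at step 200000) and 20 mid-training evaluations of 50 episodes each.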

The training curves may be found at https://wandb.ai/alexander-soare/lerobot/runs/5z9d8q9q/overview.

Training took about 5 hours on an Nvidia H100.

Evaluation

The model was evaluated on the PushT environment from gym-pusht. There are two evaluation metrics on a per-episode basis:

  • Maximum overlap with the target (logged as eval/avg_max_reward in the training curves linked above). This value lies in [0, 1].
  • Success: whether or not the maximum overlap is at least 95%.
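The two metrics can be sketched as follows, assuming the environment exposes the per-step overlap (the fraction of the target region covered) as a float in [0, 1]. The helper names and the demo values are hypothetical, and gym-pusht's own reward and success computation may differ in detail.

```python
# Minimal sketch of the two per-episode metrics and their averages.
# Assumption: each episode is given as a sequence of per-step overlap values
# in [0, 1]; this is not gym-pusht's actual implementation.
from typing import Sequence

SUCCESS_THRESHOLD = 0.95  # "maximum overlap is at least 95%"


def episode_metrics(step_overlaps: Sequence[float]) -> tuple[float, bool]:
    """Return (maximum overlap, success flag) for one episode."""
    max_overlap = max(step_overlaps)
    return max_overlap, max_overlap >= SUCCESS_THRESHOLD


def summarize(episodes: Sequence[Sequence[float]]) -> tuple[float, float]:
    """Return the average max overlap and the success rate (%) over episodes."""
    results = [episode_metrics(ep) for ep in episodes]
    avg_max = sum(m for m, _ in results) / len(results)
    success_pc = 100.0 * sum(s for _, s in results) / len(results)
    return avg_max, success_pc


# Made-up per-step overlaps for three short episodes (illustrative only).
demo = [
    [0.10, 0.60, 0.97, 0.90],  # peaks at 0.97 -> success
    [0.20, 0.80, 0.93],        # peaks at 0.93 -> failure
    [0.50, 0.99, 0.40],        # peaks at 0.99 -> success
]
avg_max, success_pc = summarize(demo)
print(f"avg max overlap = {avg_max:.3f}, success rate = {success_pc:.1f}%")
```

Note that an episode counts as a success if the overlap reaches the threshold at any step, even if the block later drifts away, because both metrics use the per-episode maximum.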

The metrics below are averaged over 500 evaluation episodes.

| Metric | Average over 500 episodes |
|-|-|
| Average max. overlap ratio | 0.97 |
| Success rate (%) | 71.0 |

The results of each of the individual rollouts may be found in eval_info.json.
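A minimal sketch for recomputing the headline numbers from eval_info.json. The file layout assumed here (a top-level "per_episode" list whose entries carry a float "max_reward" and a boolean "success") is an assumption that should be checked against the actual file, the helper names are hypothetical, and the sample data is made up.

```python
import json
from pathlib import Path


def summarize_eval_info(info: dict) -> dict:
    """Recompute the headline metrics from per-episode records.

    Assumes (hypothetically) that info["per_episode"] is a list of dicts,
    each with a float "max_reward" and a boolean "success".
    """
    episodes = info["per_episode"]
    n = len(episodes)
    return {
        "n_episodes": n,
        "avg_max_reward": sum(ep["max_reward"] for ep in episodes) / n,
        "pc_success": 100.0 * sum(bool(ep["success"]) for ep in episodes) / n,
    }


def load_eval_info(path: str) -> dict:
    """Read and parse an eval_info.json file from disk."""
    return json.loads(Path(path).read_text())


# Illustrative stand-in for the real file (values are made up).
sample = json.loads("""
{"per_episode": [
  {"episode_ix": 0, "max_reward": 1.0, "success": true},
  {"episode_ix": 1, "max_reward": 0.9, "success": false}
]}
""")
summary = summarize_eval_info(sample)
print(summary)
```

To apply it to the real file, call summarize_eval_info(load_eval_info("eval_info.json")), adjusting the key names if the actual layout differs.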