Porting Pi0-FAST to LeRobot from JAX to PyTorch: Challenges, Fixes, and Open Questions

Community Article Published April 2, 2025
Pi0+FAST is in LeRobot

Introduction

The LeRobot team has recently focused on porting Pi0+FAST, originally implemented in JAX by Physical Intelligence, to the LeRobot repository. This post outlines the main modifications we made, the challenges we faced, and the key differences from the original implementation. Our goal is to open a discussion and give the community enough context to contribute.

Background

Paper | JAX Code | Our implementation in LeRobot

π0-FAST is an autoregressive version of π0 that introduces FAST (Frequency-space Action Sequence Tokenization), a new tokenization scheme that compresses action chunks in the frequency domain (using a discrete cosine transform) to improve efficiency and performance.
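As described in the FAST paper, the tokenizer applies a discrete cosine transform (DCT) along the time axis of each action dimension, scales and rounds the coefficients, and then compresses the resulting integers with byte-pair encoding (BPE). The sketch below illustrates only the DCT and quantization steps in plain PyTorch. The BPE stage is omitted, and the helper names (`fast_like_encode`, `fast_like_decode`) and the scale value of 10 are assumptions for illustration; this is not the API of the released FAST tokenizer. The decoder also pads or truncates a malformed coefficient sequence, in the spirit of the detokenization padding/truncation listed among the modifications below.

```python
import math

import torch


def dct_matrix(n: int) -> torch.Tensor:
    """Orthonormal DCT-II matrix of shape (n, n); its inverse is its transpose."""
    k = torch.arange(n, dtype=torch.float64).unsqueeze(1)  # frequency index (rows)
    t = torch.arange(n, dtype=torch.float64).unsqueeze(0)  # time index (columns)
    mat = math.sqrt(2.0 / n) * torch.cos(math.pi * (t + 0.5) * k / n)
    mat[0] /= math.sqrt(2.0)  # rescale the DC row so the matrix is orthonormal
    return mat


def fast_like_encode(actions: torch.Tensor, scale: float = 10.0) -> torch.Tensor:
    """(horizon, action_dim) action chunk -> flat integer DCT coefficients."""
    coeffs = dct_matrix(actions.shape[0]) @ actions.double()  # DCT over time, per dimension
    quantized = torch.round(coeffs * scale).long()  # lossy step: scale and round
    # Row-major flatten: low-frequency coefficients of all dimensions come first.
    # (The real tokenizer would now compress these integers with BPE.)
    return quantized.flatten()


def fast_like_decode(flat: torch.Tensor, horizon: int, action_dim: int,
                     scale: float = 10.0) -> torch.Tensor:
    """Inverse of fast_like_encode, tolerant of too-short or too-long outputs."""
    expected = horizon * action_dim
    if flat.numel() < expected:  # pad missing high-frequency coefficients with zeros
        flat = torch.cat([flat, flat.new_zeros(expected - flat.numel())])
    flat = flat[:expected]  # truncate any extra coefficients
    coeffs = flat.reshape(horizon, action_dim).double() / scale
    return dct_matrix(horizon).T @ coeffs  # inverse DCT over time
```

Because low-frequency coefficients come first in the flattened sequence, a truncated output still decodes to a smooth approximation of the chunk instead of failing outright.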

Key Advantages of π0-FAST:

  • 5x faster training compared to diffusion-based VLAs.
  • Improved action representation, reducing redundancy in action sequences.
  • Stronger generalization across unseen environments and robot morphologies.

🔗 The π0-FAST tokenizer can be accessed here: FAST Tokenizer

🔗 Pretrained weights can be accessed here: PyTorch Pi0+FAST

Main Modifications in the PyTorch Implementation

  1. Vectorized padding and tokenization
  • Aligned with the tokenizers from Hugging Face's transformers library.
  2. Used PaliGemma's built-in transformers implementation
  • The main difference is a block-causal attention mask, similar to Pi0's, whereas PaliGemma uses a full bidirectional mask over the prefix.
  3. Added a custom prepare_inputs_for_generation
  • Needed to correctly handle the attention mask, position IDs, and other input-processing details during generation.
  4. Adjusted the prefix
  • The original Pi0+FAST implementation generates the string "Action: " as part of the output sequence; instead, we append "Action: " to the prefix and pass it in during training.
  5. No Exponential Moving Average (EMA)
  • This implementation does not use EMA.

  6. Action Padding and Masking Adjustments

  • Added padding/truncation during action detokenization to keep decoding stable, addressing the issues raised in this discussion.
  • Embedded the action loss masks in the inputs passed to the model instead of applying them as a separate, explicit step.
  • Used token type IDs to distinguish the prefix from the suffix, ensuring a correct 4D attention mask for both training and inference.
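A prefix/suffix mask of the kind described in this list can be built from per-token type IDs with a common cumulative-sum construction. The sketch below is a minimal illustration, not the LeRobot code. It assumes type 0 for prefix tokens (images, prompt, and "Action: ") and type 1 for action tokens, plus a boolean padding mask; the function names are hypothetical. Prefix tokens attend to each other bidirectionally, each action token attends to the prefix and to earlier action tokens, and padding is never attended. The output uses an additive (B, 1, L, L) layout; check which mask format your transformers version expects before passing it to a model.

```python
import torch


def block_causal_4d_mask(token_type_ids: torch.Tensor, pad_mask: torch.Tensor,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """token_type_ids: (B, L), 0 = prefix token, 1 = action token.
    pad_mask: (B, L) bool, True for real (non-padding) tokens.
    Returns an additive mask (B, 1, L, L): 0 where attention is allowed.
    """
    # The whole prefix shares block 0; every action token opens a new block.
    block_ids = torch.cumsum(token_type_ids, dim=1)
    # allowed[b, i, j]: query i may attend key j if j's block is not after i's.
    allowed = block_ids[:, None, :] <= block_ids[:, :, None]
    # Exclude padding both as query and as key.
    allowed = allowed & pad_mask[:, None, :] & pad_mask[:, :, None]
    mask = torch.zeros(allowed.shape, dtype=dtype, device=allowed.device)
    mask = mask.masked_fill(~allowed, torch.finfo(dtype).min)
    return mask[:, None, :, :]


def position_ids_and_loss_mask(token_type_ids: torch.Tensor, pad_mask: torch.Tensor):
    """Position IDs that skip padding, and a mask over real action tokens.

    The loss mask marks action-token positions; shift it as needed to align
    with next-token targets.
    """
    position_ids = (torch.cumsum(pad_mask.long(), dim=1) - 1).clamp(min=0)
    loss_mask = (token_type_ids == 1) & pad_mask
    return position_ids, loss_mask
```

Using a large finite negative value instead of `-inf` keeps fully masked rows (padded queries) from producing NaNs in the softmax.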

Results:

  • Together, these modifications reach a 40% success rate on LIBERO with tuned hyperparameters (now the defaults in configuration_pi0fast.py).

Issues and Open Questions

1. Output Action Token Inconsistencies

  • Passing the same input to JAX and PyTorch implementations does not always yield identical tokens.
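When the two implementations disagree, the index of the first differing token is usually more informative than an overall match rate: under greedy (argmax) decoding, which we assume here, one differing token changes the context for every later step. A minimal sketch, assuming both outputs have been converted to 1-D PyTorch tensors (for the JAX side, for example via NumPy); the helper names are hypothetical. As a heuristic, if the top-two logits at the divergent step are nearly tied, the difference may come from numerical precision (for example bfloat16 vs. float32) rather than from a logic bug.

```python
import torch


def first_divergence(a: torch.Tensor, b: torch.Tensor) -> int:
    """Index of the first differing token between two 1-D token sequences.

    Returns -1 if the sequences are identical, or the shorter length if one
    sequence is a strict prefix of the other.
    """
    n = min(a.numel(), b.numel())
    mismatches = (a[:n] != b[:n]).nonzero()
    if mismatches.numel() > 0:
        return int(mismatches[0, 0])
    return -1 if a.numel() == b.numel() else n


def top2_margin(step_logits: torch.Tensor) -> float:
    """Gap between the two highest logits at one decoding step (1-D vocab logits)."""
    top2 = torch.topk(step_logits, k=2).values
    return float(top2[0] - top2[1])
```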

2. Training Stability and Success Rate

  • We trained without EMA for simplicity. How crucial is EMA for fine-tuning?
  • JAX’s Pi0 does not attend to padded images, while JAX’s Pi0+FAST does. What is the rationale behind this design choice?
  • Maintaining a specific image order (e.g., exterior, wrist-left, wrist-right) might affect performance. Should this order be preserved?
  • Pi0 uses a block-causal mask, while Pi0-FAST uses bidirectional attention on prompts.
  • The JAX model uses quantile normalization, while the PyTorch port uses mean/std normalization. Is quantile normalization necessary for better performance?
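The two normalization schemes differ mainly in how they treat outliers. A minimal sketch of both, assuming per-dimension statistics computed over a (num_samples, dim) tensor; the function names are hypothetical, and the 1st/99th percentile bounds follow our reading of the FAST paper rather than the exact openpi or LeRobot code.

```python
import torch


def mean_std_normalize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor,
                       eps: float = 1e-8) -> torch.Tensor:
    """Standardize each dimension to zero mean and unit variance."""
    return (x - mean) / (std + eps)


def quantile_stats(data: torch.Tensor, low: float = 0.01, high: float = 0.99):
    """Per-dimension low/high quantiles of a (num_samples, dim) dataset."""
    return torch.quantile(data, low, dim=0), torch.quantile(data, high, dim=0)


def quantile_normalize(x: torch.Tensor, q_low: torch.Tensor, q_high: torch.Tensor,
                       eps: float = 1e-8) -> torch.Tensor:
    """Map [q_low, q_high] of each dimension linearly onto [-1, 1].

    Values outside the quantile range land outside [-1, 1]; no clipping here.
    """
    return 2.0 * (x - q_low) / (q_high - q_low + eps) - 1.0
```

With heavy-tailed data, such as a dimension with rare large values, one outlier inflates the standard deviation and squeezes typical values toward zero, while the quantile bounds ignore the extreme tails. This plausibly matters for FAST in particular, because DCT coefficients are scaled and rounded after normalization, so the normalized range determines how finely typical actions are quantized.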

3. Debugging Fine-Tuning Instability

  • Testing inference on the DROID example shows an MSE of 0.14 (vs. 0.01 for the JAX version).
  • Several generated tokens match, but the mismatch could be due to:
    • Training inconsistencies.
    • Differences in implementation.
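A single MSE value cannot separate these causes. Breaking the error down by timestep and by action dimension is a cheap first step, showing whether the discrepancy is concentrated in particular steps or dimensions (a few dimensions dominating could, for example, hint at a normalization mismatch). A minimal sketch, assuming both predictions are (horizon, action_dim) tensors in the same unnormalized units; the function name is hypothetical.

```python
import torch


def mse_breakdown(pred: torch.Tensor, target: torch.Tensor):
    """pred, target: (horizon, action_dim) action chunks in the same units.

    Returns (overall MSE, per-timestep MSE, per-dimension MSE).
    """
    sq_err = (pred - target) ** 2
    return sq_err.mean(), sq_err.mean(dim=1), sq_err.mean(dim=0)
```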

Call for Community Contributions

To conclude, despite multiple attempts and fixes, the success rate (SR) remains below the expected (reported) value:

  • Fine-tuning from the base checkpoint reaches 30% SR.
  • Fine-tuning from the LIBERO checkpoint reaches 60% SR before degrading.
  • Early checkpoints (around 5k training steps) yield the highest SR, suggesting differences in the training recipe.

The model should ideally achieve 80% SR, as reported in the original paper. We invite contributions to refine and improve this implementation.
