JAX/Flax Implementation
DeepMind's Gemma implementation does not seem to have been updated for the new release.
Are there any plans to release the JAX/Flax implementation and model?
There is! Our focus was on getting the weights out properly. For my own curiosity, why are you interested in Flax/JAX in particular?
I think using TPUs is the most cost-effective way to fully fine-tune the 27B model.
Additionally, the JAX/Flax implementation is useful as a reference implementation. For Gemma 1, DeepMind's implementation was the only one without bugs.
@canyon289 This would be very convenient. I want to integrate it with our JORA library (JAX-centered LLM PEFT fine-tuning). I believe the only differences from Gemma 1/1.1 are:
- logit soft-capping,
- sliding window attention, and
- query normalization.
Plus, we'd need the weights in Flax format (i.e. an `orbax.checkpoint` checkpoint).
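For anyone less familiar with the first three items, here is a minimal single-head sketch in JAX of attention-logit soft-capping, a causal sliding-window mask, and a configurable query scale. It is an illustration under assumptions, not DeepMind's implementation: the function names (`soft_cap`, `sliding_window_causal_mask`, `attention`) and their parameters are hypothetical, and the `query_scale` argument is only an assumed stand-in for whatever query normalization the real model uses.

```python
import jax
import jax.numpy as jnp


def soft_cap(x, cap):
    """Logit soft-capping: smoothly squashes values into (-cap, cap)."""
    return cap * jnp.tanh(x / cap)


def sliding_window_causal_mask(seq_len, window):
    """Boolean [seq_len, seq_len] mask: query i may attend to key j
    only if j <= i (causal) and i - j < window (local window)."""
    q_pos = jnp.arange(seq_len)[:, None]
    k_pos = jnp.arange(seq_len)[None, :]
    return (k_pos <= q_pos) & (q_pos - k_pos < window)


def attention(q, k, v, *, attn_cap, window=None, query_scale=None):
    """Single-head causal attention with soft-capping and an optional window.

    q, k, v: [seq_len, head_dim]. query_scale defaults to head_dim ** -0.5
    (hypothetical stand-in for the model's real query normalization).
    """
    seq_len, head_dim = q.shape
    if query_scale is None:
        query_scale = head_dim ** -0.5
    logits = (q * query_scale) @ k.T          # [seq_len, seq_len]
    logits = soft_cap(logits, attn_cap)        # cap the attention logits
    if window is None:
        window = seq_len                        # plain (global) causal attention
    mask = sliding_window_causal_mask(seq_len, window)
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v
```

Soft-capping is applied before masking, so the large negative fill value never passes through `tanh`. Alternating `window=None` (global) and a finite `window` (local) across layers would give a mix of global and sliding-window layers.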
Thank you both for the answers. There are a couple of other changes, such as GQA! Regardless, it's still being worked on and should be out soon. My apologies for the delay.
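Grouped-query attention (GQA) lets several query heads share one key/value head, which shrinks the KV cache. Below is a minimal causal sketch in JAX, assuming a single unbatched sequence and simply repeating the K/V heads to line up with the query heads. The function name and shapes are hypothetical, and any head counts you pass in are arbitrary, not Gemma 2's actual configuration.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """Causal grouped-query attention (hypothetical sketch).

    q:    [seq_len, num_q_heads, head_dim]
    k, v: [seq_len, num_kv_heads, head_dim], num_q_heads % num_kv_heads == 0.
    Each consecutive group of num_q_heads // num_kv_heads query heads
    shares one K/V head.
    """
    seq_len, num_q_heads, head_dim = q.shape
    num_kv_heads = k.shape[1]
    assert num_q_heads % num_kv_heads == 0
    group = num_q_heads // num_kv_heads
    # Repeat each K/V head `group` times: [kv0, kv0, ..., kv1, kv1, ...].
    k = jnp.repeat(k, group, axis=1)  # -> [seq_len, num_q_heads, head_dim]
    v = jnp.repeat(v, group, axis=1)
    logits = jnp.einsum("qhd,khd->hqk", q, k) * head_dim ** -0.5
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    logits = jnp.where(causal[None], logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)   # [heads, q_len, k_len]
    return jnp.einsum("hqk,khd->qhd", weights, v)
```

Materializing the repeated K/V with `jnp.repeat` keeps the sketch short; an optimized implementation would more likely reshape the queries into groups to avoid the copy.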
JORA looks interesting! I'd suggest adding a link to the paper in the readme.
We haven't forgotten about this. We're making some final changes and it's on its way to release.
I'd also suggest sending a PR to add it to https://github.com/n2cholas/awesome-jax
It's updated! Check it out, folks. Hope you enjoy the models.
@canyon289 Hi, could you point me to where the JAX/Flax implementation of the model is? I couldn't find any Python code for the Gemma 2 implementation, only weight files on Kaggle.
The official JAX repo has the configurations for Gemma 2: https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py