snowpark
Collection
JAX ports of popular models (audio, vision, LLM, diffusion, etc.)
I implemented DINOv2 in JAX for TPU inference and converted the pretrained weights to JAX. Model code here. This repository contains only the pickle file holding the converted weights (a pytree / state dict).
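For anyone who wants to inspect the weights before wiring them into the model code, here is a minimal sketch of loading a pickled pytree in JAX. The file name `dinov2_small.pkl` is hypothetical, and the sketch assumes the pickle holds a nested container (dicts/lists) of NumPy or JAX arrays; the actual file name and layout in this repo may differ. Note that `pickle` can execute arbitrary code on load, so only unpickle files from sources you trust.

```python
import pickle

import jax
import jax.numpy as jnp


def load_params(path):
    """Load a pickled pytree and convert every leaf to a JAX array.

    Assumption: the file contains a nested dict/list of arrays.
    """
    # pickle.load can run arbitrary code; only use it on trusted files.
    with open(path, "rb") as f:
        params = pickle.load(f)
    # tree_map applies jnp.asarray to each leaf and keeps the structure.
    return jax.tree_util.tree_map(jnp.asarray, params)


def count_params(params):
    """Total number of scalar values across all leaves of the pytree."""
    return sum(int(leaf.size) for leaf in jax.tree_util.tree_leaves(params))
```

Usage would look like `params = load_params("dinov2_small.pkl")` (hypothetical path), followed by `count_params(params)` as a quick check that the conversion produced the expected model size.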
Base model
facebook/dinov2-small