"""Convert a PyTorch Whisper checkpoint to Flax and save it alongside the original."""
from whisper_jax import FlaxWhisperForConditionalGeneration
import jax.numpy as jnp  # not used by the statements below; kept as in the original

# Checkpoint location; used both as the load source and the save target.
checkpoint_id = "/media/user01/HDWINDOWS/whisper-medium-portuguese"

# Convert PyTorch weights to Flax: from_pt=True asks from_pretrained to load
# the PyTorch weights found at checkpoint_id into the Flax model.
model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_pt=True)

# Save the converted Flax model in the same directory, so the checkpoint can
# presumably be loaded later without from_pt=True (confirm against the
# whisper-jax documentation).
model.save_pretrained(checkpoint_id)