"""Convert a PyTorch Whisper checkpoint into Flax weights for whisper-jax.

The checkpoint is loaded with ``from_pt=True``, so its PyTorch weights are
converted to Flax while loading. The resulting Flax model is then written with
``save_pretrained``, by default back into the same directory.

Usage::

    python this_script.py [CHECKPOINT_DIR] [--output-dir DIR]
"""
import argparse

# NOTE(review): jnp is not used by the visible code; kept in case other code
# in this file relies on it.
import jax.numpy as jnp
from whisper_jax import FlaxWhisperForConditionalGeneration

# Default checkpoint location, used when no directory is given on the command line.
checkpoint_id = "/media/user01/HDWINDOWS/whisper-medium-portuguese"


def convert_pt_to_flax(checkpoint_dir=checkpoint_id, output_dir=None):
    """Load a PyTorch Whisper checkpoint as a Flax model and save it.

    Args:
        checkpoint_dir: Checkpoint path passed to
            ``FlaxWhisperForConditionalGeneration.from_pretrained``.
        output_dir: Directory the converted Flax model is saved to. Defaults
            to *checkpoint_dir*, i.e. the Flax weights are written next to
            the PyTorch ones.

    Returns:
        The converted ``FlaxWhisperForConditionalGeneration`` instance.
    """
    # from_pt=True makes from_pretrained convert the PyTorch weights to Flax.
    model = FlaxWhisperForConditionalGeneration.from_pretrained(
        checkpoint_dir, from_pt=True
    )
    target_dir = checkpoint_dir if output_dir is None else output_dir
    model.save_pretrained(target_dir)
    return model


def main(argv=None):
    """Parse command-line arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert a PyTorch Whisper checkpoint to Flax weights."
    )
    parser.add_argument(
        "checkpoint_dir",
        nargs="?",
        default=checkpoint_id,
        help="Checkpoint directory holding the PyTorch weights.",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Where to save the Flax model (defaults to CHECKPOINT_DIR).",
    )
    args = parser.parse_args(argv)
    convert_pt_to_flax(args.checkpoint_dir, args.output_dir)


if __name__ == "__main__":
    main()