# Source: uploaded by RogerioFreitas in commit b925bf3 ("Upload 22 files").
import argparse

import jax.numpy as jnp
from whisper_jax import FlaxWhisperForConditionalGeneration
# Checkpoint directory used when none is given on the command line. This is the
# path hard-coded by the original author and is machine-specific.
DEFAULT_CHECKPOINT_DIR = "/media/user01/HDWINDOWS/whisper-medium-portuguese"


def convert_checkpoint(checkpoint_dir, output_dir=None):
    """Load a PyTorch Whisper checkpoint, convert it to Flax, and save it.

    Args:
        checkpoint_dir: Directory containing the PyTorch checkpoint.
        output_dir: Directory to save the converted Flax model to. Defaults to
            ``checkpoint_dir``, so the Flax weights are written next to the
            PyTorch ones.

    Returns:
        The converted Flax model.
    """
    # from_pt=True tells from_pretrained to read PyTorch weights and convert
    # them to Flax.
    model = FlaxWhisperForConditionalGeneration.from_pretrained(
        checkpoint_dir, from_pt=True
    )
    model.save_pretrained(checkpoint_dir if output_dir is None else output_dir)
    return model


def parse_args(argv=None):
    """Parse command-line arguments for the conversion script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint_dir`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Convert a PyTorch Whisper checkpoint to Flax weights."
    )
    parser.add_argument(
        "checkpoint_dir",
        nargs="?",
        default=DEFAULT_CHECKPOINT_DIR,
        help="directory containing the PyTorch checkpoint (default: %(default)s)",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="directory to save the Flax model to (default: checkpoint_dir)",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Command-line entry point: convert the requested checkpoint to Flax."""
    args = parse_args(argv)
    convert_checkpoint(args.checkpoint_dir, args.output_dir)


# Guard so that importing this module does not trigger a full model conversion.
if __name__ == "__main__":
    main()