import argparse | |
from transformers import RobertaForMaskedLM | |
def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./", *, model_class=None):
    """
    Convert Flax model weights to PyTorch weights and save them.

    The checkpoint at ``flax_model_path`` is loaded with ``from_flax=True``,
    which makes Transformers convert the Flax weights into a PyTorch model.
    That model is then written out with ``save_pretrained``.

    Args:
        flax_model_path: Location of the Flax checkpoint. This is passed
            straight to ``from_pretrained``; the CLI default is a Hub model id.
        torch_model_path: Destination passed to ``save_pretrained``.
            Defaults to the current directory.
        model_class: Model class used to load the checkpoint. It must provide
            ``from_pretrained(path, from_flax=True)``, and the instance must
            provide ``save_pretrained(path)``. Defaults to
            ``RobertaForMaskedLM``, which keeps the original behavior.

    Returns:
        The converted PyTorch model instance.
    """
    # Resolve the default lazily so the hard-coded architecture is only used
    # when the caller did not ask for a different one.
    if model_class is None:
        model_class = RobertaForMaskedLM
    model = model_class.from_pretrained(flax_model_path, from_flax=True)
    model.save_pretrained(torch_model_path)
    return model
if __name__ == "__main__":
    # Command-line entry point: parse the source/destination paths and run
    # the conversion once.
    parser = argparse.ArgumentParser(description="Flax to PyTorch model conversion")
    parser.add_argument(
        "--flax_model_path", type=str, default="flax-community/roberta-pretraining-hindi", help="Flax model path"
    )
    parser.add_argument("--torch_model_path", type=str, default="./", help="PyTorch model path")
    args = parser.parse_args()
    convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)