GPT2-Jokes / export_h5.py
ameerazam08's picture
Create export_h5.py
ebc6d95
raw
history blame contribute delete
554 Bytes
from transformers import TFGPT2LMHeadModel, GPT2Config, GPT2LMHeadModel
# Convert a fine-tuned GPT-2 checkpoint saved in PyTorch format into a
# TensorFlow checkpoint (tf_model.h5) loadable by TFGPT2LMHeadModel.

# Directory holding the trained PyTorch checkpoint. Placeholder value --
# replace with the real directory before running.
pytorch_model_path = "trained_path"
# Directory the TensorFlow weights are written to. Defaults to the source
# directory (same as before); point it elsewhere to leave the PyTorch
# checkpoint directory untouched.
tf_output_path = pytorch_model_path

# Load the model configuration once and reuse it for the TF model.
config = GPT2Config.from_pretrained(pytorch_model_path)

# Build the TF model directly from the PyTorch weights: `from_pt=True` makes
# transformers read the PyTorch checkpoint and convert it in memory.
# No separate GPT2LMHeadModel load is needed. Loading one with
# `from_tf=True` would look for TensorFlow weights, which do not exist until
# this script has run.
tf_model = TFGPT2LMHeadModel.from_pretrained(pytorch_model_path, from_pt=True, config=config)

# Save the TensorFlow model; this generates the tf_model.h5 file in
# tf_output_path.
tf_model.save_pretrained(tf_output_path)