import os

import torch
from transformers import AutoTokenizer, pipeline

GPT_WEIGHTS_NAME = "pyg.pt"


def model_fn(model_dir):
    """Load the serialized model and tokenizer from model_dir and return a text-generation pipeline."""
    # Load the pickled model object; map to CPU first so loading also works
    # on hosts without a GPU (the pipeline moves it to the chosen device below).
    model = torch.load(os.path.join(model_dir, GPT_WEIGHTS_NAME), map_location="cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1

    generation = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )
    return generation
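

# --- Usage sketch (illustrative, not part of the handler itself) ---
# A minimal local smoke test, assuming a directory "./model" that contains
# pyg.pt plus the tokenizer files saved alongside it; the directory name,
# prompt, and generation parameters below are placeholders.
if __name__ == "__main__":
    generator = model_fn("model")
    result = generator("Hello, my name is", max_new_tokens=20, do_sample=True)
    print(result[0]["generated_text"])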