# pyg-pt / code / inference.py
# Uploaded by MrD05 ("Upload 8 files", commit e290a20)
import os
import torch
from transformers import AutoTokenizer, pipeline
# File name of the serialized model that model_fn loads from the model directory.
GPT_WEIGHTS_NAME = "pyg.pt"
def model_fn(model_dir):
    """Load the serialized model and tokenizer and build a generation pipeline.

    Args:
        model_dir: Directory holding the ``GPT_WEIGHTS_NAME`` file and the
            tokenizer files that ``AutoTokenizer.from_pretrained`` reads.

    Returns:
        A ``transformers`` "text-generation" pipeline, placed on GPU 0 when
        CUDA is available and on the CPU otherwise.
    """
    use_cuda = torch.cuda.is_available()
    # pipeline() takes a device ordinal: 0 is the first GPU, -1 is the CPU.
    device = 0 if use_cuda else -1

    # Remap stored tensors onto the device we are about to serve from.
    # Without map_location, a checkpoint saved from GPU tensors fails to
    # deserialize on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so model_dir must only
    # ever contain trusted artifacts. Newer PyTorch releases default to
    # weights_only=True, which presumably rejects a fully pickled model
    # object like this one -- confirm against the deployed torch version.
    model = torch.load(
        os.path.join(model_dir, GPT_WEIGHTS_NAME),
        map_location="cuda:0" if use_cuda else "cpu",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )