# Phi2_RCT1M-ft-heading / inference.py
# Hugging Face Hub repository by SaborDay, commit b24bef1
# ("Creating inference script"), originally 402 bytes.
from transformers import AutoTokenizer
import transformers
import torch
from peft import PeftModel, PeftConfig
# Hub repository that holds the fine-tuned PEFT adapter weights.
model_id = "SaborDay/Phi2_RCT1M-ft-heading"


def run_inference(prompt, max_length=1000):
    """Generate a completion for ``prompt`` using the adapter at ``model_id``.

    The base model is whichever model the adapter's ``PeftConfig`` records in
    ``base_model_name_or_path``. The tokenizer is loaded from that same base
    model.

    Args:
        prompt: Input text to condition generation on.
            NOTE(review): the prompt template used during fine-tuning is not
            visible in this file. Match it to get meaningful output.
        max_length: Maximum total sequence length in tokens (prompt plus newly
            generated tokens). Passed straight to ``generate``.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed. For a causal LM this includes the prompt text, because
        ``generate`` returns the input ids followed by the new tokens.
    """
    # The original script referenced `model`, `tokenizer` and `inputs` without
    # ever defining them (NameError). Resolve the base model from the adapter
    # config and load everything explicitly.
    config = PeftConfig.from_pretrained(model_id)
    base_id = config.base_model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(base_id)
    base_model = transformers.AutoModelForCausalLM.from_pretrained(base_id)

    # Attach the fine-tuned adapter weights from the Hub to the base model.
    trained_model = PeftModel.from_pretrained(base_model, model_id)
    trained_model.eval()  # make sure inference mode is on (e.g. dropout off)

    # Tensors must be on the same device as the model weights.
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    with torch.no_grad():  # gradients are not needed for generation
        outputs = trained_model.generate(
            **inputs,
            max_length=max_length,
            # Explicit pad id avoids the "pad_token_id not set" warning
            # for tokenizers that define no pad token.
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} PROMPT")
    print(run_inference(" ".join(sys.argv[1:])))