Spaces:
Runtime error
Runtime error
File size: 1,327 Bytes
e79483d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from safetensors.torch import load_model
from transformers import RobertaTokenizer, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
def get_roberta():
    """Build the RoBERTa sequence classifier and its tokenizer.

    The model is instantiated from the ``cross-encoder/nli-roberta-base``
    checkpoint, after which the tensors stored in the local file
    ``roberta.safetensors`` are loaded into it. The tokenizer comes from
    the same checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'cross-encoder/nli-roberta-base'

    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    # Replace the checkpoint's parameters with the locally stored ones.
    load_model(classifier, "roberta.safetensors")

    tok = AutoTokenizer.from_pretrained(checkpoint)
    return tok, classifier
def get_gpt():
    """Build the GPT-2 sequence classifier and its tokenizer.

    The model is instantiated from the ``gpt2`` checkpoint with three
    output labels, and the tensors stored in the local file
    ``gpt.safetensors`` are then loaded into it. Both the model config and
    the tokenizer are set to use the end-of-sequence token for padding.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    classifier = GPT2ForSequenceClassification.from_pretrained(
        'gpt2',
        num_labels=3,
    )
    # Reuse the EOS token id as the padding id on the model side.
    cfg = classifier.config
    cfg.pad_token_id = cfg.eos_token_id
    load_model(classifier, "gpt.safetensors")

    tok = GPT2TokenizerFast.from_pretrained('gpt2')
    # Mirror the same choice on the tokenizer side.
    tok.pad_token = tok.eos_token
    return tok, classifier
def get_distilbert():
    """Build the DistilBERT sequence classifier and its tokenizer.

    The model is instantiated from the ``distilbert-base-uncased``
    checkpoint with three output labels, and the tensors stored in the
    local file ``distilbert.safetensors`` are then loaded into it. The
    tokenizer comes from the same checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'distilbert-base-uncased'

    classifier = DistilBertForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=3,
    )
    # Replace the freshly initialised parameters with the locally stored ones.
    load_model(classifier, "distilbert.safetensors")

    tok = DistilBertTokenizerFast.from_pretrained(checkpoint)
    return tok, classifier