Clemet committed on
Commit
e79483d
1 Parent(s): f2a699d

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +36 -0
utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import load_model
2
+ from transformers import RobertaTokenizer, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
3
+ from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
4
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
5
+
6
+
7
+
8
+
9
def get_roberta(weights_path: str = "roberta.safetensors") -> tuple:
    """Build the RoBERTa sequence classifier and its tokenizer.

    The model architecture and the tokenizer are both taken from the
    ``cross-encoder/nli-roberta-base`` checkpoint. The model's weights are
    then loaded from a local safetensors file via
    ``safetensors.torch.load_model``.

    Args:
        weights_path: Path of the safetensors file to load into the model.
            Defaults to ``"roberta.safetensors"``; a relative path is
            resolved against the current working directory.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    # Single source of truth so the model and tokenizer cannot drift apart.
    checkpoint = 'cross-encoder/nli-roberta-base'

    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    load_model(model, weights_path)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    return tokenizer, model
17
+
18
def get_gpt(weights_path: str = "gpt.safetensors") -> tuple:
    """Build the GPT-2 sequence classifier and its tokenizer.

    The model is a ``GPT2ForSequenceClassification`` with a 3-label head,
    created from the ``gpt2`` checkpoint. Its weights are then loaded from a
    local safetensors file via ``safetensors.torch.load_model``.

    Args:
        weights_path: Path of the safetensors file to load into the model.
            Defaults to ``"gpt.safetensors"``; a relative path is resolved
            against the current working directory.

    Returns:
        A ``(tokenizer, model)`` pair. Both are configured to use the
        end-of-sequence token as the padding token.
    """
    # Single source of truth so the model and tokenizer cannot drift apart.
    checkpoint = 'gpt2'

    model = GPT2ForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
    # Reuse EOS as the pad token id on the model side; the tokenizer below is
    # configured the same way so padded batches stay consistent with it.
    model.config.pad_token_id = model.config.eos_token_id
    load_model(model, weights_path)

    tokenizer = GPT2TokenizerFast.from_pretrained(checkpoint)
    tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, model
28
+
29
def get_distilbert(weights_path: str = "distilbert.safetensors") -> tuple:
    """Build the DistilBERT sequence classifier and its tokenizer.

    The model is a ``DistilBertForSequenceClassification`` with a 3-label
    head, created from the ``distilbert-base-uncased`` checkpoint. Its
    weights are then loaded from a local safetensors file via
    ``safetensors.torch.load_model``.

    Args:
        weights_path: Path of the safetensors file to load into the model.
            Defaults to ``"distilbert.safetensors"``; a relative path is
            resolved against the current working directory.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    # Single source of truth so the model and tokenizer cannot drift apart.
    checkpoint = 'distilbert-base-uncased'

    model = DistilBertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
    load_model(model, weights_path)

    tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)

    return tokenizer, model