Snizhanna committed on
Commit
b227c5c
·
verified ·
1 Parent(s): 4031604

Update utils_models.py

Browse files
Files changed (1) hide show
  1. utils_models.py +27 -27
utils_models.py CHANGED
@@ -1,27 +1,27 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
-
4
- def map_num_to_label(num):
5
- return "сарказм" if num==1 else "не сарказм"
6
-
7
- def load_roberta():
8
- model_ckpt = "ukr-roberta-base-finetuned-sarc"
9
- tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
10
- id2label = {1: "sarcastic",0: "not_sarcastic"}
11
- label2id = {"sarcastic": 1, "not_sarcastic": 0}
12
- hf_model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=2, label2id=label2id, id2label=id2label)
13
- return hf_model, tokenizer
14
-
15
- def predict_roberta(model, tokenizer, text):
16
- tokenized_input = tokenizer(text, return_tensors="pt")
17
- predictions = model(**tokenized_input)
18
- prediction = predictions.logits.argmax().item()
19
- return map_num_to_label(prediction)
20
-
21
- def identity_tokenizer(text):
22
- return text
23
-
24
- def predict_lr_rf(model, vectorizer, text):
25
- prediction = model.predict(vectorizer.transform([text]))
26
- return map_num_to_label(prediction)
27
-
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+
4
def map_num_to_label(num):
    """Translate a numeric class id into its Ukrainian display label.

    Args:
        num: Predicted class id. A value equal to 1 means "sarcasm";
            anything else is reported as "not sarcasm".

    Returns:
        The Ukrainian label string for the given class id.
    """
    if num == 1:
        return "сарказм"
    return "не сарказм"
6
+
7
def load_roberta():
    """Load the fine-tuned sarcasm classifier together with its tokenizer.

    Both objects are loaded from the
    ``Snizhanna/ukr-roberta-base-finetuned-sarc`` checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "Snizhanna/ukr-roberta-base-finetuned-sarc"

    # The two mappings mirror each other: class 1 is the sarcastic label.
    label2id = {"sarcastic": 1, "not_sarcastic": 0}
    id2label = {1: "sarcastic", 0: "not_sarcastic"}

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=2,
        label2id=label2id,
        id2label=id2label,
    )
    return classifier, tokenizer
14
+
15
def predict_roberta(model, tokenizer, text):
    """Classify a single text as sarcastic or not with the RoBERTa model.

    Args:
        model: Sequence-classification model. It is called with the
            tokenizer output as keyword arguments and must return an
            object exposing ``.logits``.
        tokenizer: Tokenizer called as
            ``tokenizer(text, return_tensors="pt", truncation=True)``.
        text: The single string to classify.

    Returns:
        The label ``map_num_to_label`` gives for the highest-scoring class.
    """
    # Truncate so that text longer than the model's maximum sequence length
    # does not overflow its position embeddings and crash the forward pass.
    tokenized_input = tokenizer(text, return_tensors="pt", truncation=True)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        predictions = model(**tokenized_input)

    # One input text -> one row of logits; pick the best class on that row.
    prediction = predictions.logits.argmax(dim=-1).item()
    return map_num_to_label(prediction)
20
+
21
def identity_tokenizer(text):
    """Return ``text`` unchanged.

    NOTE(review): presumably supplied as the tokenizer callable of a
    vectorizer that receives already-tokenized input -- confirm against
    the caller.
    """
    return text
23
+
24
def predict_lr_rf(model, vectorizer, text):
    """Classify a single text with a vectorizer + classical model pipeline.

    Args:
        model: Fitted estimator exposing ``predict``.
            NOTE(review): presumably a logistic-regression or random-forest
            classifier, judging by the function name -- confirm.
        vectorizer: Fitted vectorizer exposing ``transform``.
        text: The single input to classify.

    Returns:
        The label ``map_num_to_label`` gives for the predicted class.
    """
    # ``transform`` takes a batch, so wrap the single text in a list.
    features = vectorizer.transform([text])
    predictions = model.predict(features)
    # ``predict`` returns one entry per input row. Pass the single scalar
    # rather than the whole container, so the ``== 1`` comparison in
    # ``map_num_to_label`` does not depend on one-element array truthiness.
    return map_num_to_label(predictions[0])
27
+