cointegrated
committed on
Commit
•
1f2d9ad
1
Parent(s):
735822e
Update README.md
Browse files
README.md
CHANGED
@@ -11,11 +11,15 @@ from transformers import AutoModelForSequenceClassification, BertTokenizer
|
|
11 |
model_name = 'cointegrated/rubert-base-cased-dp-paraphrase-detection'
|
12 |
model = AutoModelForSequenceClassification.from_pretrained(model_name).cuda()
|
13 |
tokenizer = BertTokenizer.from_pretrained(model_name)
|
14 |
-
|
15 |
-
|
16 |
-
batch = tokenizer(text1, text2, return_tensors='pt').to(model.device)
|
17 |
-
with torch.inference_mode():
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
```
|
|
|
11 |
model_name = 'cointegrated/rubert-base-cased-dp-paraphrase-detection'
|
12 |
model = AutoModelForSequenceClassification.from_pretrained(model_name).cuda()
|
13 |
tokenizer = BertTokenizer.from_pretrained(model_name)
|
14 |
+
|
15 |
+
def compare_texts(text1, text2):
|
16 |
+
batch = tokenizer(text1, text2, return_tensors='pt').to(model.device)
|
17 |
+
with torch.inference_mode():
|
18 |
+
proba = torch.softmax(model(**batch).logits, -1).cpu().numpy()
|
19 |
+
return proba[0] # p(non-paraphrase), p(paraphrase)
|
20 |
+
|
21 |
+
print(compare_texts('Сегодня на улице хорошая погода', 'Сегодня на улице отвратительная погода'))
|
22 |
+
# [0.7056226 0.2943774]
|
23 |
+
print(compare_texts('Сегодня на улице хорошая погода', 'Отличная погодка сегодня выдалась'))
|
24 |
+
# [0.16524374 0.8347562 ]
|
25 |
```
|