cointegrated commited on
Commit
1f2d9ad
1 Parent(s): 735822e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -7
README.md CHANGED
@@ -11,11 +11,15 @@ from transformers import AutoModelForSequenceClassification, BertTokenizer
11
  model_name = 'cointegrated/rubert-base-cased-dp-paraphrase-detection'
12
  model = AutoModelForSequenceClassification.from_pretrained(model_name).cuda()
13
  tokenizer = BertTokenizer.from_pretrained(model_name)
14
- text1 = 'Сегодня на улице хорошая погода'
15
- text2 = 'Сегодня на улице отвратительная погода'
16
- batch = tokenizer(text1, text2, return_tensors='pt').to(model.device)
17
- with torch.inference_mode():
18
- proba = torch.softmax(model(**batch).logits, -1).cpu().numpy()
19
- print(proba)
20
- # [[0.44876656 0.5512334 ]]
 
 
 
 
21
  ```
 
11
  model_name = 'cointegrated/rubert-base-cased-dp-paraphrase-detection'
12
  model = AutoModelForSequenceClassification.from_pretrained(model_name).cuda()
13
  tokenizer = BertTokenizer.from_pretrained(model_name)
14
+
15
+ def compare_texts(text1, text2):
16
+ batch = tokenizer(text1, text2, return_tensors='pt').to(model.device)
17
+ with torch.inference_mode():
18
+ proba = torch.softmax(model(**batch).logits, -1).cpu().numpy()
19
+ return proba[0] # p(non-paraphrase), p(paraphrase)
20
+
21
+ print(compare_texts('Сегодня на улице хорошая погода', 'Сегодня на улице отвратительная погода'))
22
+ # [0.7056226 0.2943774]
23
+ print(compare_texts('Сегодня на улице хорошая погода', 'Отличная погодка сегодня выдалась'))
24
+ # [0.16524374 0.8347562 ]
25
  ```