Update README.md
README.md
CHANGED
@@ -32,81 +32,15 @@ The preferable usage:
Removed:

```python
# pip install transformers
import transformers
from typing import Dict, List
from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("tinkoff-ai/response-quality-classifier-tiny")
model = AutoModel.from_pretrained("tinkoff-ai/response-quality-classifier-tiny")
# model.cuda()

context_3 = 'привет'
context_2 = 'привет!'
context_1 = 'как дела?'
response = 'у меня все хорошо, а у тебя как?'

sample = {
    'context_3': context_3,
    'context_2': context_2,
    'context_1': context_1,
    'response': response
}

SEP_TOKEN = '[SEP]'
CLS_TOKEN = '[CLS]'
RESPONSE_TOKEN = '[RESPONSE_TOKEN]'
MAX_SEQ_LENGTH = 128
sorted_dialog_columns = ['context_3', 'context_2', 'context_1', 'response']


def tokenize_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict,
    max_seq_length: int,
    sorted_dialog_columns: List,
):
    """
    Tokenize both contexts and response of dialog data separately
    """
    len_message_history = len(sorted_dialog_columns)
    max_seq_length = min(max_seq_length, tokenizer.model_max_length)
    max_each_message_length = max_seq_length // len_message_history - 1
    messages = [sample[k] for k in sorted_dialog_columns]
    result = {model_input_name: [] for model_input_name in tokenizer.model_input_names}
    messages = [str(message) if message is not None else '' for message in messages]
    tokens = tokenizer(
        messages, padding=False, max_length=max_each_message_length, truncation=True, add_special_tokens=False
    )
    for model_input_name in tokens.keys():
        result[model_input_name].extend(tokens[model_input_name])
    return result


def merge_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict
):
    cls_token = tokenizer(CLS_TOKEN, add_special_tokens=False)
    sep_token = tokenizer(SEP_TOKEN, add_special_tokens=False)
    response_token = tokenizer(RESPONSE_TOKEN, add_special_tokens=False)
    model_input_names = tokenizer.model_input_names
    result = {}
    for model_input_name in model_input_names:
        tokens = []
        tokens.extend(cls_token[model_input_name])
        for i, message in enumerate(sample[model_input_name]):
            tokens.extend(message)
            if i < len(sample[model_input_name]) - 2:
                tokens.extend(sep_token[model_input_name])
            elif i == len(sample[model_input_name]) - 2:
                tokens.extend(response_token[model_input_name])
        result[model_input_name] = torch.tensor([tokens])
        if torch.cuda.is_available():
            result[model_input_name] = result[model_input_name].cuda()
    return result


tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
tokens = merge_dialog_data(tokenizer, tokenized_dialog)
with torch.inference_mode():
    logits = model(**tokens).logits
    probas = torch.sigmoid(logits)[0].cpu().detach().numpy()

print(probas)
```
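The updated snippet below drops the per-field tokenization and feeds the whole dialog to the tokenizer as one string. As a rough bridge between the two versions, here is a minimal sketch assuming the same `[SEP]`/`[RESPONSE_TOKEN]` special tokens; the helper name `build_model_input` is hypothetical and not part of the README:

```python
def build_model_input(contexts, response):
    # contexts are ordered from the oldest to the most recent turn;
    # the removed example additionally prepended a [CLS] token by hand
    return '[SEP]'.join(contexts) + '[RESPONSE_TOKEN]' + response

text = build_model_input(['привет', 'привет!', 'как дела?'], 'норм')
# -> 'привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]норм'
```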
Added:

```python
# pip install transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("/mnt/chatbot_models2/chit-chat/experiments/crossencoder_hf/rubert-base-sentence/dialogs_whole")
model = AutoModelForSequenceClassification.from_pretrained("/mnt/chatbot_models2/chit-chat/experiments/crossencoder_hf/rubert-base-sentence/dialogs_whole")
# model.cuda()
inputs = tokenizer('привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]норм',
                   padding=True, max_length=128, truncation=True, add_special_tokens=False, return_tensors='pt')
with torch.inference_mode():
    logits = model(**inputs).logits
    probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
print(probas)
```
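The example prints the raw sigmoid probabilities. As a hedged addition, assuming the checkpoint has two output logits (relevance and specificity, as in the tinkoff-ai response-quality classifiers referenced in the removed snippet), they can be unpacked like this:

```python
# Assumption: the classification head has two labels, [relevance, specificity];
# check model.config.num_labels / model.config.id2label for your checkpoint.
relevance, specificity = probas
print(f'relevance={relevance:.3f}, specificity={specificity:.3f}')
```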
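Several candidate responses for the same context can also be scored in one batch; a sketch under the same input-format assumptions (the candidate strings are made up for illustration):

```python
# Score a few candidate responses for the same dialog context in one forward pass.
candidates = ['норм', 'не знаю', 'отлично, спасибо!']
texts = [f'привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]{c}' for c in candidates]
batch = tokenizer(texts, padding=True, max_length=128, truncation=True,
                  add_special_tokens=False, return_tensors='pt')
# if model.cuda() is uncommented above, move the batch to the GPU as well:
# batch = {k: v.cuda() for k, v in batch.items()}
with torch.inference_mode():
    batch_probas = torch.sigmoid(model(**batch).logits).cpu().numpy()
print(batch_probas)  # one row of probabilities per candidate
```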