---
license: mit
widget:
  - text: >-
      привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]супер, вот только
      проснулся, у тебя как?
    example_title: Dialog example 1
  - text: привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]норм
    example_title: Dialog example 2
  - text: привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]норм, у тя как?
    example_title: Dialog example 3
---

This classification model is based on cointegrated/rubert-tiny2. It should be used to estimate the relevance and specificity of the last message in the context of a dialogue.

The labels explanation:

- relevance: whether the last message in the dialogue is relevant in the context of the full dialogue
- specificity: whether the last message in the dialogue is interesting and promotes the continuation of the dialogue

The preferred dialogue length is 4 messages, where the last message is the one to be evaluated.
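
For illustration, a dialogue is concatenated into a single string in the same way as in the widget examples above: context messages separated by `[SEP]`, and the message to be scored prefixed with `[RESPONSE_TOKEN]`. A minimal sketch (the helper name is illustrative, not part of the model's API):

```python
def to_widget_text(context_3: str, context_2: str, context_1: str, response: str) -> str:
    # Same layout as the widget examples: ctx[SEP]ctx[SEP]ctx[RESPONSE_TOKEN]response
    return f"{context_3}[SEP]{context_2}[SEP]{context_1}[RESPONSE_TOKEN]{response}"

print(to_widget_text('привет', 'привет!', 'как дела?', 'норм, у тя как?'))
# привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]норм, у тя как?
```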

It is pretrained on a corpus of dialogue data from social networks and fine-tuned on tinkoff-ai/context_similarity. The performance of the model on the validation split of tinkoff-ai/context_similarity (with the best thresholds for the validation samples):

|             | f0.5 | ROC AUC |
|-------------|------|---------|
| relevance   | 0.82 | 0.74    |
| specificity | 0.81 | 0.8     |
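
The f0.5 values are reported at the best decision threshold found on the validation split. A minimal sketch of how such a threshold could be selected for one of the two heads (the names `val_probas` and `val_labels` are illustrative; this is not necessarily the authors' exact procedure):

```python
import numpy as np
from sklearn.metrics import fbeta_score

def best_f05_threshold(val_probas: np.ndarray, val_labels: np.ndarray) -> float:
    # Scan candidate thresholds and keep the one that maximizes f0.5
    thresholds = np.linspace(0.0, 1.0, 101)
    scores = [fbeta_score(val_labels, val_probas >= t, beta=0.5) for t in thresholds]
    return float(thresholds[int(np.argmax(scores))])
```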

The preferred usage:

```python
# pip install transformers
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from typing import List, Dict
tokenizer = AutoTokenizer.from_pretrained("tinkoff-ai/response-quality-classifier-tiny")
model = AutoModelForSequenceClassification.from_pretrained("tinkoff-ai/response-quality-classifier-tiny")
# Move the model to GPU so it matches the inputs prepared in merge_dialog_data below
if torch.cuda.is_available():
    model = model.cuda()
context_3 = 'привет'
context_2 = 'привет!'
context_1 = 'как дела?'
response = 'у меня все хорошо, а у тебя как?'

sample = {
    'context_3': context_3,
    'context_2': context_2,
    'context_1': context_1,
    'response': response
}

SEP_TOKEN = '[SEP]'
CLS_TOKEN = '[CLS]'
RESPONSE_TOKEN = '[RESPONSE_TOKEN]'
MAX_SEQ_LENGTH = 128
sorted_dialog_columns = ['context_3', 'context_2', 'context_1', 'response']

def tokenize_dialog_data(
        tokenizer: transformers.PreTrainedTokenizer,
        sample: Dict,
        max_seq_length: int,
        sorted_dialog_columns: List,
):
    """
    Tokenize both contexts and response of dialog data separately
    """
    len_message_history = len(sorted_dialog_columns)
    max_seq_length = min(max_seq_length, tokenizer.model_max_length)
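    # Split the sequence-length budget evenly across the messages, keeping one
    # position per message free for a special token ([CLS] / [SEP] / [RESPONSE_TOKEN])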
    max_each_message_length = max_seq_length // len_message_history - 1
    messages = [sample[k] for k in sorted_dialog_columns]
    result = {model_input_name: [] for model_input_name in tokenizer.model_input_names}
    messages = [str(message) if message is not None else '' for message in messages]
    tokens = tokenizer(
        messages, padding=False, max_length=max_each_message_length, truncation=True, add_special_tokens=False
    )
    for model_input_name in tokens.keys():
        result[model_input_name].extend(tokens[model_input_name])
    return result

def merge_dialog_data(
        tokenizer: transformers.PreTrainedTokenizer,
        sample: Dict
):
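    # Build a single token sequence per model input:
    # [CLS] context_3 [SEP] context_2 [SEP] context_1 [RESPONSE_TOKEN] response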
    cls_token = tokenizer(CLS_TOKEN, add_special_tokens=False)
    sep_token = tokenizer(SEP_TOKEN, add_special_tokens=False)
    response_token = tokenizer(RESPONSE_TOKEN, add_special_tokens=False)
    model_input_names = tokenizer.model_input_names
    result = {}
    for model_input_name in model_input_names:
        tokens = []
        tokens.extend(cls_token[model_input_name])
        for i, message in enumerate(sample[model_input_name]):
            tokens.extend(message)
            if i < len(sample[model_input_name]) - 2:
                tokens.extend(sep_token[model_input_name])
            elif i == len(sample[model_input_name]) - 2:
                tokens.extend(response_token[model_input_name])
        result[model_input_name] = torch.tensor([tokens])
        if torch.cuda.is_available():
            result[model_input_name] = result[model_input_name].cuda()
    return result

tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
tokens = merge_dialog_data(tokenizer, tokenized_dialog)
with torch.inference_mode():
    logits = model(**tokens).logits
    probas = torch.sigmoid(logits)[0].cpu().detach().numpy()

print(probas)
```
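
`probas` holds two values. Assuming they follow the order of the labels described above (relevance first, specificity second), a minimal interpretation sketch with illustrative 0.5 cut-offs (not the tuned thresholds from the table) could look like:

```python
# Assumption: probas = [relevance, specificity]; the 0.5 cut-offs are illustrative only
relevance, specificity = probas
print(f"relevance: {relevance:.2f}, specificity: {specificity:.2f}")
if relevance >= 0.5 and specificity >= 0.5:
    print("The response looks relevant and specific enough to keep the dialogue going.")
```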