import os
import streamlit as st
import transformers
import torch
import tokenizers
from typing import List, Dict
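
# The UI strings below are in Russian. The subheader reads: "This demo lets you experiment
# with models that estimate how well a proposed response fits the dialogue context."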
st.subheader('Эта демонстрация позволяет поэкспериментировать с моделями, которые оценивают, насколько предлагаемый ответ подходит к контексту диалога.')
model_name = st.selectbox(
    'Выберите модель',
    ('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base', 'tinkoff-ai/response-quality-classifier-large')
)
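# If no TOKEN is set in the environment, use_auth_token=True makes transformers fall back to
# the token cached by `huggingface-cli login` (assuming one is available on the host).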
auth_token = os.environ.get('TOKEN') or True

@st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
def load_model(model_name: str):
    """Load the tokenizer and classifier once per model name, moving the model to GPU when available."""
    with st.spinner('Loading models...'):
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
        model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=auth_token)
        if torch.cuda.is_available():
            model = model.cuda()
    return tokenizer, model
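
# A fixed three-turn dialogue history; the user only edits the agent's candidate response below.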
context_3 = 'привет'
context_2 = 'привет!'
context_1 = 'как дела?'
st.markdown('👱🏻‍♀️ **Настя**: ' + context_3)
st.markdown('🤖 **Диалоговый агент**: ' + context_2)
st.markdown('👱🏻‍♀️ **Настя**: ' + context_1)
response = st.text_input('🤖 Диалоговый агент:', 'норм')
sample = {
    'context_3': context_3,
    'context_2': context_2,
    'context_1': context_1,
    'response': response
}
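
# merge_dialog_data below assembles the model input as:
#   [CLS] context_3 [SEP] context_2 [SEP] context_1 [RESPONSE_TOKEN] response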
SEP_TOKEN = '[SEP]'
CLS_TOKEN = '[CLS]'
RESPONSE_TOKEN = '[RESPONSE_TOKEN]'
MAX_SEQ_LENGTH = 128
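# Dialogue turns ordered from the oldest context message to the candidate response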
sorted_dialog_columns = ['context_3', 'context_2', 'context_1', 'response']

def tokenize_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict,
    max_seq_length: int,
    sorted_dialog_columns: List,
):
    """Tokenize the contexts and the response of the dialogue separately."""
    len_message_history = len(sorted_dialog_columns)
    max_seq_length = min(max_seq_length, tokenizer.model_max_length)
    # Give each message an equal share of the budget, leaving room for one special token per message
    max_each_message_length = max_seq_length // len_message_history - 1
    messages = [sample[k] for k in sorted_dialog_columns]
    result = {model_input_name: [] for model_input_name in tokenizer.model_input_names}
    messages = [str(message) if message is not None else '' for message in messages]
    tokens = tokenizer(
        messages, padding=False, max_length=max_each_message_length, truncation=True, add_special_tokens=False
    )
    for model_input_name in tokens.keys():
        result[model_input_name].extend(tokens[model_input_name])
    return result

def merge_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict
):
    """Concatenate the per-message token lists into a single input with the special tokens in between."""
    cls_token = tokenizer(CLS_TOKEN, add_special_tokens=False)
    sep_token = tokenizer(SEP_TOKEN, add_special_tokens=False)
    response_token = tokenizer(RESPONSE_TOKEN, add_special_tokens=False)
    model_input_names = tokenizer.model_input_names
    result = {}
    for model_input_name in model_input_names:
        tokens = []
        tokens.extend(cls_token[model_input_name])
        for i, message in enumerate(sample[model_input_name]):
            tokens.extend(message)
            if i < len(sample[model_input_name]) - 2:
                # Context turns are separated by [SEP]
                tokens.extend(sep_token[model_input_name])
            elif i == len(sample[model_input_name]) - 2:
                # [RESPONSE_TOKEN] marks the boundary between the last context turn and the response
                tokens.extend(response_token[model_input_name])
        result[model_input_name] = torch.tensor([tokens])
        if torch.cuda.is_available():
            result[model_input_name] = result[model_input_name].cuda()
    return result

@st.cache
def inference(model_name: str, sample: dict):
    """Tokenize the dialogue, run the classifier and return the per-label probabilities."""
    tokenizer, model = load_model(model_name)
    tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
    tokens = merge_dialog_data(tokenizer, tokenized_dialog)
    with torch.inference_mode():
        logits = model(**tokens).logits
        probas = torch.sigmoid(logits)[0].cpu().detach().numpy().tolist()
    return probas
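
# probas[0] is the probability that the response is relevant to the context,
# probas[1] is the probability that it is engaging (see the metric labels below).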
with st.spinner('Running inference...'):
    probas = inference(model_name, sample)

st.metric(
    label='Вероятность того, что последний ответ диалогового агента релевантный',
    value=round(probas[0], 3)
)
st.metric(
    label='Вероятность того, что последний ответ диалогового агента вовлечённый',
    value=round(probas[1], 3)
)