metadata
license: unlicense
Задача Incomplete Utterance Restoration
Генеративная модель на основе sberbank-ai/rugpt3large_based_on_gpt2 для восстановления полного текста реплик в диалоге из контекста.
Допустим, последние 2 строки диалога имеют вид:
- Как тебя зовут?
- Джульетта Мао
Модель позволяет получить полный текст последней реплики, в раскрытом виде:
Меня зовут Джульетта Мао
Раскрытая реплика позволяет использовать многие классические инструменты NLP для обработки, включая регулярные выражения, классификаторы интентов и т.д.
Пример использования
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
model.to(device)
# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
# В конце добавляем символ "#"
input_text = """<s>- Как тебя зовут?
- Джульетта Мао #"""
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(device)
output_sequences = model.generate(
input_ids=encoded_prompt,
max_length=100,
temperature=1.0,
top_k=30,
top_p=0.85,
repetition_penalty=1.2,
do_sample=True,
num_return_sequences=1,
pad_token_id=0
)
generated_sequence = output_sequences[0].tolist()
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
text = text[: text.find('</s>')]
text = text[text.find('#')+1:].strip() # Результат генерации содержит входную строку, поэтому отрезаем ее до символа "#".
print(text)