rugpt_interpreter / README.md
koziev ilya
Добавлен базовый сценарий использования
513d6cb
|
raw
history blame
2.42 kB
metadata
license: unlicense

Задача Incomplete Utterance Restoration

Генеративная модель на основе sberbank-ai/rugpt3large_based_on_gpt2 для восстановления полного текста реплик в диалоге из контекста.

Допустим, последние 2 строки диалога имеют вид:

- Как тебя зовут?
- Джульетта Мао

Модель позволяет получить полный текст последней реплики, в раскрытом виде:

Меня зовут Джульетта Мао

Раскрытая реплика позволяет использовать многие классические инструменты NLP для обработки, включая регулярные выражения, классификаторы интентов и т.д.

Пример использования

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
model.to(device)

# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
# В конце добавляем символ "#"
input_text = """<s>- Как тебя зовут?
- Джульетта Мао #"""
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(device)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=100,
    temperature=1.0,
    top_k=30,
    top_p=0.85,
    repetition_penalty=1.2,
    do_sample=True,
    num_return_sequences=1,
    pad_token_id=0
)

generated_sequence = output_sequences[0].tolist()
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
text = text[: text.find('</s>')]
text = text[text.find('#')+1:].strip() # Результат генерации содержит входную строку, поэтому отрезаем ее до символа "#".
print(text)