|
--- |
|
tags: Text generation |
|
|
|
license: unlicense |
|
|
|
language: ru |
|
|
|
widget: |
|
- text: "- Как тебя зовут? - Джульетта Мао #" |
|
- text: "- А живешь где? - В поясе астероидов #" |
|
--- |
|
|
|
|
|
## Задача Incomplete Utterance Restoration |
|
|
|
Генеративная модель на основе [sberbank-ai/rugpt3large_based_on_gpt2](https://huggingface.co/sberbank-ai/rugpt3large_based_on_gpt2) |
|
для восстановления полного текста реплик в диалоге из контекста. |
|
|
|
Допустим, последние 2 строки диалога имеют вид: |
|
|
|
``` |
|
- Как тебя зовут? |
|
- Джульетта Мао |
|
``` |
|
|
|
Модель позволяет получить полный текст последней реплики, в раскрытом виде: |
|
|
|
``` |
|
Меня зовут Джульетта Мао |
|
``` |
|
|
|
Раскрытая реплика позволяет использовать многие классические инструменты NLP для обработки, включая регулярные выражения, классификаторы |
|
интентов и т.д. |
|
|
|
|
|
## Пример использования |
|
|
|
|
|
``` |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter") |
|
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'}) |
|
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter") |
|
model.to(device) |
|
|
|
# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-" |
|
# В конце добавляем символ "#" |
|
input_text = """<s>- Как тебя зовут? |
|
- Джульетта Мао #""" |
|
#input_text = """<s>- Что Предтечи забрали у Предшественников? |
|
#- Они узурпировали у них Мантию — защиту всего живого в галактике #""" |
|
|
|
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device) |
|
|
|
output_sequences = model.generate( |
|
input_ids=encoded_prompt, |
|
max_length=100, |
|
temperature=1.0, |
|
top_k=30, |
|
top_p=0.85, |
|
repetition_penalty=1.2, |
|
do_sample=True, |
|
num_return_sequences=1, |
|
pad_token_id=tokenizer.pad_token_id, |
|
) |
|
|
|
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:] |
|
text = text[: text.find('</s>')] |
|
print(text) |
|
``` |
|
|