rugpt_interpreter / README.md
koziev ilya
немного причесал код, убрал лишние манипуляции с выдачей gpt
e3f5def
|
raw
history blame
2.71 kB
metadata
tags: Text generation
license: unlicense
language: ru
widget:
  - text: '- Как тебя зовут? - Джульетта Мао #'
  - text: '- А живешь где? - В поясе астероидов #'

Задача Incomplete Utterance Restoration

Генеративная модель на основе sberbank-ai/rugpt3large_based_on_gpt2 для восстановления полного текста реплик в диалоге из контекста.

Допустим, последние 2 строки диалога имеют вид:

- Как тебя зовут?
- Джульетта Мао

Модель позволяет получить полный текст последней реплики, в раскрытом виде:

Меня зовут Джульетта Мао

Раскрытая реплика позволяет использовать многие классические инструменты NLP для обработки, включая регулярные выражения, классификаторы интентов и т.д.

Пример использования

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
model.to(device)

# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
# В конце добавляем символ "#"
input_text = """<s>- Как тебя зовут?
- Джульетта Мао #"""
#input_text = """<s>- Что Предтечи забрали у Предшественников?
#- Они узурпировали у них Мантию — защиту всего живого в галактике #"""

encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=100,
    temperature=1.0,
    top_k=30,
    top_p=0.85,
    repetition_penalty=1.2,
    do_sample=True,
    num_return_sequences=1,
    pad_token_id=tokenizer.pad_token_id,
)

text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
text = text[: text.find('</s>')]
print(text)