File size: 2,985 Bytes
8e698f5
7a5f863
8e698f5
7a5f863
8e698f5
 
 
3d83edc
8e698f5
 
0f694cf
d5e16b4
8e698f5
 
 
 
7eb2542
a0e7286
 
425069f
7f7732b
 
a0e7286
 
 
7f7732b
 
 
a0e7286
 
7f7732b
 
8e698f5
 
3712cee
8e698f5
3f4021d
8e698f5
d5e16b4
5ec68f4
 
8e698f5
3f4021d
8e698f5
 
 
a0e7286
8e698f5
 
 
 
224ced6
8e698f5
 
 
 
 
 
 
a0e7286
8e698f5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch
st.title("""
 History Mystery
 """)
# Добавление слайдера
temp = st.slider("Градус дичи", 1.0, 20.0, 1.0)
sen_quan = st.slider(" Длина сгенерированного отрывка", 2, 10, 5)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
#@st.cache_resource(allow_output_mutation=True)
def load_gpt():
    model_GPT = GPT2LMHeadModel.from_pretrained(
     'sberbank-ai/rugpt3small_based_on_gpt2',
     output_attentions=False,
     output_hidden_states=False,
    )
    tokenizer_GPT = GPT2Tokenizer.from_pretrained(
        'sberbank-ai/rugpt3small_based_on_gpt2',
        output_attentions=False,
        output_hidden_states=False,
    )
    model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
    return model_GPT, tokenizer_GPT

#model, tokenizer = load_gpt()
# # Вешаем сохраненные веса на нашу модель
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
    # Преобразование входной строки в токены
    input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
    # Генерация текста
    output = model_GPT.generate(input_ids=input_ids, max_length=80, num_beams=5, do_sample=True,
                            temperature=temp, top_k=50, top_p=0.6, no_repeat_ngram_size=4,
                            num_return_sequences=sen_quan)
    # Декодирование сгенерированного текста
    generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
    return generated_text
# Streamlit приложение
def main():
    model_GPT, tokenizer_GPT = load_gpt()
    st.write("""
    # GPT-3 генерация текста
    """)
    # Ввод строки пользователем
    prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на Руси")
    # # Генерация текста по введенной строке
    # generated_text = generate_text(prompt)
    # Создание кнопки "Сгенерировать"
    generate_button = st.button("За работу!")
    # Обработка события нажатия кнопки
    if generate_button:
    # Вывод сгенерированного текста
        generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
        st.subheader("Продолжение:")
        st.write(generated_text)
if __name__ == "__main__":
    main()