File size: 3,985 Bytes
8e698f5
7a5f863
8e698f5
7a5f863
8e698f5
 
1bf4b30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e698f5
3d83edc
8e698f5
 
0f694cf
d5e16b4
8e698f5
 
 
 
7eb2542
a0e7286
 
425069f
7f7732b
 
a0e7286
 
 
7f7732b
 
 
a0e7286
 
7f7732b
 
8e698f5
 
3712cee
8e698f5
3f4021d
8e698f5
d5e16b4
5ec68f4
 
8e698f5
3f4021d
8e698f5
 
 
a0e7286
8e698f5
 
 
 
224ced6
8e698f5
 
 
 
 
 
 
a0e7286
8e698f5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch

import base64
import plotly.express as px

df = px.data.iris()

@st.cache_data
def get_img_as_base64(file):
    with open(file, "rb") as f:
        data = f.read()
    return base64.b64encode(data).decode()


page_bg_img = f"""
<style>
[data-testid="stAppViewContainer"] > .main {{
background-image: url("https://wallpapercave.com/wp/wp6480460.jpg");
background-size: 115%;
background-position: top left;
background-repeat: no-repeat;
background-attachment: local;
}}

[data-testid="stSidebar"] > div:first-child {{
background-image: url("https://ibb.co/ZBkdJRg");
background-size: 115%;
background-position: center; 
background-repeat: no-repeat;
background-attachment: fixed;
}}

[data-testid="stHeader"] {{
background: rgba(0,0,0,0);
}}

[data-testid="stToolbar"] {{
right: 2rem;
}}

div.css-1n76uvr.e1tzin5v0 {{
background-color: rgba(238, 238, 238, 0.5);
border: 10px solid #EEEEEE;
padding: 5% 5% 5% 10%;
border-radius: 5px;
}}


</style>
"""
st.markdown(page_bg_img, unsafe_allow_html=True)

st.title("""
 History Mystery
 """)
# Добавление слайдера
temp = st.slider("Градус дичи", 1.0, 20.0, 1.0)
sen_quan = st.slider(" Длина сгенерированного отрывка", 2, 10, 5)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
#@st.cache_resource(allow_output_mutation=True)
def load_gpt():
    model_GPT = GPT2LMHeadModel.from_pretrained(
     'sberbank-ai/rugpt3small_based_on_gpt2',
     output_attentions=False,
     output_hidden_states=False,
    )
    tokenizer_GPT = GPT2Tokenizer.from_pretrained(
        'sberbank-ai/rugpt3small_based_on_gpt2',
        output_attentions=False,
        output_hidden_states=False,
    )
    model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
    return model_GPT, tokenizer_GPT

#model, tokenizer = load_gpt()
# # Вешаем сохраненные веса на нашу модель
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
    # Преобразование входной строки в токены
    input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
    # Генерация текста
    output = model_GPT.generate(input_ids=input_ids, max_length=80, num_beams=5, do_sample=True,
                            temperature=temp, top_k=50, top_p=0.6, no_repeat_ngram_size=4,
                            num_return_sequences=sen_quan)
    # Декодирование сгенерированного текста
    generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
    return generated_text
# Streamlit приложение
def main():
    model_GPT, tokenizer_GPT = load_gpt()
    st.write("""
    # GPT-3 генерация текста
    """)
    # Ввод строки пользователем
    prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на Руси")
    # # Генерация текста по введенной строке
    # generated_text = generate_text(prompt)
    # Создание кнопки "Сгенерировать"
    generate_button = st.button("За работу!")
    # Обработка события нажатия кнопки
    if generate_button:
    # Вывод сгенерированного текста
        generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
        st.subheader("Продолжение:")
        st.write(generated_text)
if __name__ == "__main__":
    main()