Spaces:
Sleeping
Sleeping
import transformers | |
import streamlit as st | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import numpy as np | |
from PIL import Image | |
import torch | |
import base64 | |
import plotly.express as px | |
df = px.data.iris() | |
def get_img_as_base64(file): | |
with open(file, "rb") as f: | |
data = f.read() | |
return base64.b64encode(data).decode() | |
page_bg_img = f""" | |
<style> | |
[data-testid="stAppViewContainer"] > .main {{ | |
background-image: url("https://wallpapercave.com/wp/wp6480460.jpg"); | |
background-size: 115%; | |
background-position: top left; | |
background-repeat: no-repeat; | |
background-attachment: local; | |
}} | |
[data-testid="stSidebar"] > div:first-child {{ | |
background-image: url("https://ibb.co/ZBkdJRg"); | |
background-size: 115%; | |
background-position: center; | |
background-repeat: no-repeat; | |
background-attachment: fixed; | |
}} | |
[data-testid="stHeader"] {{ | |
background: rgba(0,0,0,0); | |
}} | |
[data-testid="stToolbar"] {{ | |
right: 2rem; | |
}} | |
div.css-1n76uvr.e1tzin5v0 {{ | |
background-color: rgba(238, 238, 238, 0.5); | |
border: 10px solid #EEEEEE; | |
padding: 5% 5% 5% 10%; | |
border-radius: 5px; | |
}} | |
</style> | |
""" | |
st.markdown(page_bg_img, unsafe_allow_html=True) | |
st.title(""" | |
History Mystery | |
""") | |
# Добавление слайдера | |
temp = st.slider("Градус дичи", 1.0, 20.0, 1.0) | |
sen_quan = st.slider(" Длина сгенерированного отрывка", 2, 10, 5) | |
# Загрузка модели и токенизатора | |
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') | |
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') | |
# #Задаем класс модели (уже в streamlit/tg_bot) | |
#@st.cache_resource(allow_output_mutation=True) | |
def load_gpt(): | |
model_GPT = GPT2LMHeadModel.from_pretrained( | |
'sberbank-ai/rugpt3small_based_on_gpt2', | |
output_attentions=False, | |
output_hidden_states=False, | |
) | |
tokenizer_GPT = GPT2Tokenizer.from_pretrained( | |
'sberbank-ai/rugpt3small_based_on_gpt2', | |
output_attentions=False, | |
output_hidden_states=False, | |
) | |
model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu'))) | |
return model_GPT, tokenizer_GPT | |
#model, tokenizer = load_gpt() | |
# # Вешаем сохраненные веса на нашу модель | |
# Функция для генерации текста | |
def generate_text(model_GPT, tokenizer_GPT, prompt): | |
# Преобразование входной строки в токены | |
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt') | |
# Генерация текста | |
output = model_GPT.generate(input_ids=input_ids, max_length=80, num_beams=5, do_sample=True, | |
temperature=temp, top_k=50, top_p=0.6, no_repeat_ngram_size=4, | |
num_return_sequences=sen_quan) | |
# Декодирование сгенерированного текста | |
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True) | |
return generated_text | |
# Streamlit приложение | |
def main(): | |
model_GPT, tokenizer_GPT = load_gpt() | |
st.write(""" | |
# GPT-3 генерация текста | |
""") | |
# Ввод строки пользователем | |
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на Руси") | |
# # Генерация текста по введенной строке | |
# generated_text = generate_text(prompt) | |
# Создание кнопки "Сгенерировать" | |
generate_button = st.button("За работу!") | |
# Обработка события нажатия кнопки | |
if generate_button: | |
# Вывод сгенерированного текста | |
generated_text = generate_text(model_GPT, tokenizer_GPT, prompt) | |
st.subheader("Продолжение:") | |
st.write(generated_text) | |
if __name__ == "__main__": | |
main() | |