Spaces:
Sleeping
Sleeping
File size: 3,985 Bytes
8e698f5 7a5f863 8e698f5 7a5f863 8e698f5 1bf4b30 8e698f5 3d83edc 8e698f5 0f694cf d5e16b4 8e698f5 7eb2542 a0e7286 425069f 7f7732b a0e7286 7f7732b a0e7286 7f7732b 8e698f5 3712cee 8e698f5 3f4021d 8e698f5 d5e16b4 5ec68f4 8e698f5 3f4021d 8e698f5 a0e7286 8e698f5 224ced6 8e698f5 a0e7286 8e698f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch
import base64
import plotly.express as px
df = px.data.iris()
@st.cache_data
def get_img_as_base64(file):
with open(file, "rb") as f:
data = f.read()
return base64.b64encode(data).decode()
page_bg_img = f"""
<style>
[data-testid="stAppViewContainer"] > .main {{
background-image: url("https://wallpapercave.com/wp/wp6480460.jpg");
background-size: 115%;
background-position: top left;
background-repeat: no-repeat;
background-attachment: local;
}}
[data-testid="stSidebar"] > div:first-child {{
background-image: url("https://ibb.co/ZBkdJRg");
background-size: 115%;
background-position: center;
background-repeat: no-repeat;
background-attachment: fixed;
}}
[data-testid="stHeader"] {{
background: rgba(0,0,0,0);
}}
[data-testid="stToolbar"] {{
right: 2rem;
}}
div.css-1n76uvr.e1tzin5v0 {{
background-color: rgba(238, 238, 238, 0.5);
border: 10px solid #EEEEEE;
padding: 5% 5% 5% 10%;
border-radius: 5px;
}}
</style>
"""
st.markdown(page_bg_img, unsafe_allow_html=True)
st.title("""
History Mystery
""")
# Добавление слайдера
temp = st.slider("Градус дичи", 1.0, 20.0, 1.0)
sen_quan = st.slider(" Длина сгенерированного отрывка", 2, 10, 5)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
#@st.cache_resource(allow_output_mutation=True)
def load_gpt():
model_GPT = GPT2LMHeadModel.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions=False,
output_hidden_states=False,
)
tokenizer_GPT = GPT2Tokenizer.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions=False,
output_hidden_states=False,
)
model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
return model_GPT, tokenizer_GPT
#model, tokenizer = load_gpt()
# # Вешаем сохраненные веса на нашу модель
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
# Преобразование входной строки в токены
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
# Генерация текста
output = model_GPT.generate(input_ids=input_ids, max_length=80, num_beams=5, do_sample=True,
temperature=temp, top_k=50, top_p=0.6, no_repeat_ngram_size=4,
num_return_sequences=sen_quan)
# Декодирование сгенерированного текста
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
model_GPT, tokenizer_GPT = load_gpt()
st.write("""
# GPT-3 генерация текста
""")
# Ввод строки пользователем
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на Руси")
# # Генерация текста по введенной строке
# generated_text = generate_text(prompt)
# Создание кнопки "Сгенерировать"
generate_button = st.button("За работу!")
# Обработка события нажатия кнопки
if generate_button:
# Вывод сгенерированного текста
generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
st.subheader("Продолжение:")
st.write(generated_text)
if __name__ == "__main__":
main()
|