History / app.py
SaviAnna's picture
Update app.py
1a8c51c
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch
st.title("""
History Mystery
""")
# Добавление слайдера
temperature = st.slider("Градус дичи", 1, 20, 1)
max_len = st.slider(" Длина сгенерированного отрывка", 40, 120, 2)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
@st.cache
def load_gpt():
model_GPT = GPT2LMHeadModel.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions = False,
output_hidden_states = False,
)
tokenizer_GPT = GPT2Tokenizer.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions = False,
output_hidden_states = False,
)
model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
return model_GPT, tokenizer_GPT
# # Вешаем сохраненные веса на нашу модель
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
# Преобразование входной строки в токены
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
# Генерация текста
output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
temperature=1., top_k=50, top_p=0.6, no_repeat_ngram_size=3,
num_return_sequences=3)
# Декодирование сгенерированного текста
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
model_GPT, tokenizer_GPT = load_gpt()
st.write("""
# GPT-3 генерация текста
""")
# Ввод строки пользователем
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
# # Генерация текста по введенной строке
# generated_text = generate_text(prompt)
# Создание кнопки "Сгенерировать"
generate_button = st.button("За работу!")
# Обработка события нажатия кнопки
if generate_button:
# Вывод сгенерированного текста
generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
st.subheader("Продолжение:")
st.write(generated_text)
if __name__ == "__main__":
main()