import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
# Добавление слайдера
temperature = st.slider("Выберите градус недоверия", 1.0, 20.0, 1.0)

st.title("""
 # History Mistery
 """)
# image = Image.open('data-scins.jpeg')

# st.image(image, caption='Current mood')

# Загрузка модели и токенизатора
model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
# model_finetuned = GPT2LMHeadModel.from_pretrained(
#     'sberbank-ai/rugpt3small_based_on_gpt2',
#     output_attentions = False,
#     output_hidden_states = False,
# )

# # Вешаем сохраненные веса на нашу модель
# model_finetuned.load_state_dict(torch.load('model_hostory.pt'))
# Функция для генерации текста
def generate_text(prompt):
    # Преобразование входной строки в токены
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Генерация текста
    output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
                            temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
                            num_return_sequences=3)

    # Декодирование сгенерированного текста
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return generated_text

# Streamlit приложение
def main():
    st.write("""
    # GPT-3 генерация текста
    """)

    # Ввод строки пользователем
    prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")

    # Генерация текста по введенной строке
    generated_text = generate_text(prompt)
    # Создание кнопки "Сгенерировать"
generate_button = st.button("Сгенерировать")

    # Обработка события нажатия кнопки
    if generate_button:
    # Генерация текста на основе пользовательского ввода
        generated_text = generate_text(user_input)
    # Вывод сгенерированного текста
        st.subheader("Продолжение:")
        st.write(generated_text)

if __name__ == "__main__":
    main()