import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Model + tokenizer loading, cached by Streamlit so the checkpoint is
# read from disk only once per server process.
@st.cache_resource
def load_model():
    """Load the causal LM and its tokenizer from the local checkpoint.

    Returns:
        tuple: ``(model, tokenizer)`` loaded from ``models/gpt``.
    """
    checkpoint = "models/gpt"
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return lm, tok
def generate_text(model, tokenizer, prompt, gen_params):
    """Generate one or more continuations of *prompt*.

    Args:
        model: causal LM exposing ``generate(...)`` (transformers-style).
        tokenizer: tokenizer exposing ``__call__``, ``decode`` and
            ``eos_token_id``.
        prompt (str): seed text to continue.
        gen_params (dict): sampling settings (``max_length``,
            ``temperature``, ``top_k``, ``top_p``,
            ``num_return_sequences``); entries set to ``None`` are
            treated as "not configured" and omitted.

    Returns:
        list[str]: one formatted result string per generated sequence.
    """
    inputs = tokenizer(prompt, return_tensors="pt")

    # The caller sets top_k / top_p to None in the "Temperature" mode.
    # Forwarding None into generate() breaks generation-config
    # validation in several transformers versions, so drop unset keys
    # and let the library defaults apply.
    sampling_kwargs = {k: v for k, v in gen_params.items() if v is not None}

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the mask explicitly to avoid the "attention mask is
            # not set" warning and ambiguous padding behavior.
            attention_mask=inputs.attention_mask,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )

    generated = []
    for i, output in enumerate(outputs):
        text = tokenizer.decode(output, skip_special_tokens=True)
        generated.append(f"Генерация {i+1}:\n{text}\n{'-'*50}")
    return generated
def main():
    """Streamlit entry point: page header, sidebar controls, generation."""
    # NOTE(review): the original header calls contained unterminated
    # string literals (a syntax error — the file could not be parsed).
    # Restored as plain markdown strings with the original text;
    # confirm whether styled HTML wrappers were intended, since
    # unsafe_allow_html=True suggests markup was stripped.
    st.markdown("Генератор текста", unsafe_allow_html=True)
    st.markdown("(ну почти)", unsafe_allow_html=True)
    st.markdown("---")

    # Center the banner image in the middle column.
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        st.image('images/scale_1200.png', width=500)

    # Model/tokenizer come from the cached loader (loaded once per process).
    model, tokenizer = load_model()

    # Generation settings live in the sidebar.
    with st.sidebar:
        st.header("Настройки генерации")
        prompt = st.text_area("Введите начальный текст:", height=100)
        max_length = st.slider("Максимальная длина:", 50, 500, 100)
        num_return_sequences = st.slider("Число генераций:", 1, 5, 1)
        st.subheader("Параметры выборки:")
        sampling_method = st.radio("Метод:", ["Temperature", "Top-k & Top-p"])
        if sampling_method == "Temperature":
            temperature = st.slider("Temperature:", 0.1, 2.0, 1.0, 0.1)
            # None signals "unset" — generate_text's caller contract.
            top_k = None
            top_p = None
        else:
            temperature = 1.0
            top_k = st.slider("Top-k:", 1, 100, 50)
            top_p = st.slider("Top-p:", 0.1, 1.0, 0.9, 0.05)

    if st.sidebar.button("Сгенерировать текст"):
        # Guard clause: nothing to generate without a prompt.
        if not prompt:
            st.warning("Введите начальный текст!")
            return
        gen_params = {
            'max_length': max_length,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'num_return_sequences': num_return_sequences
        }
        with st.spinner("Прибухиваем..."):
            generated = generate_text(model, tokenizer, prompt, gen_params)
        st.markdown("---")
        st.subheader("Результаты:")
        for i, text in enumerate(generated):
            # A unique key is required: multiple text_areas with the
            # same (empty) label raise DuplicateWidgetID in Streamlit.
            st.text_area(label="", value=text, height=200, key=f"result_{i}")


if __name__ == "__main__":
    main()