import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# Load the model and tokenizer once per server process (cached by Streamlit).
@st.cache_resource
def load_model():
    """Load and cache the causal LM and its tokenizer from the local path."""
    model_name = "models/gpt"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer


def generate_text(model, tokenizer, prompt, gen_params):
    """Generate one or more sampled continuations of *prompt*.

    Parameters
    ----------
    model, tokenizer : the causal LM and its tokenizer (from ``load_model``).
    prompt : str
        The text to continue.
    gen_params : dict
        Keys: ``max_length``, ``temperature``, ``top_k``, ``top_p``,
        ``num_return_sequences``. ``top_k`` / ``top_p`` may be ``None``
        (pure temperature sampling) and are then omitted from the call.

    Returns
    -------
    list[str]
        One formatted string per generated sequence.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Only forward sampling knobs that are actually set: passing top_k=None
    # or top_p=None into generate() raises on older transformers versions,
    # which expect an int/float when the argument is given.
    sampling_kwargs = {
        key: gen_params[key]
        for key in ("temperature", "top_k", "top_p")
        if gen_params.get(key) is not None
    }
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Explicit attention mask avoids the pad/eos ambiguity warning
            # and guarantees correct masking when pad_token_id == eos.
            attention_mask=inputs.attention_mask,
            max_length=gen_params['max_length'],
            num_return_sequences=gen_params['num_return_sequences'],
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )
    generated = []
    for i, output in enumerate(outputs):
        text = tokenizer.decode(output, skip_special_tokens=True)
        generated.append(f"Генерация {i+1}:\n{text}\n{'-'*50}")
    return generated


def main():
    """Streamlit entry point: sidebar controls + generation results."""
    # NOTE(review): the original header strings appear to have lost their
    # HTML markup during extraction (unsafe_allow_html=True with no visible
    # tags); the text content is preserved verbatim — confirm against the
    # original file.
    st.markdown(
        """

Генератор текста

""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """

(ну почти)

""",
        unsafe_allow_html=True,
    )
    st.markdown("---")

    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        st.image('images/scale_1200.png', width=500)

    # Load the (cached) model.
    model, tokenizer = load_model()

    # Generation settings live in the sidebar.
    with st.sidebar:
        st.header("Настройки генерации")
        prompt = st.text_area("Введите начальный текст:", height=100)
        max_length = st.slider("Максимальная длина:", 50, 500, 100)
        num_return_sequences = st.slider("Число генераций:", 1, 5, 1)

        st.subheader("Параметры выборки:")
        sampling_method = st.radio("Метод:", ["Temperature", "Top-k & Top-p"])
        if sampling_method == "Temperature":
            temperature = st.slider("Temperature:", 0.1, 2.0, 1.0, 0.1)
            # Pure temperature sampling: no top-k / top-p filtering.
            top_k = None
            top_p = None
        else:
            temperature = 1.0
            top_k = st.slider("Top-k:", 1, 100, 50)
            top_p = st.slider("Top-p:", 0.1, 1.0, 0.9, 0.05)

    # Generation button and results.
    if st.sidebar.button("Сгенерировать текст"):
        if not prompt:
            st.warning("Введите начальный текст!")
            return

        gen_params = {
            'max_length': max_length,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'num_return_sequences': num_return_sequences,
        }

        with st.spinner("Прибухиваем..."):
            generated = generate_text(model, tokenizer, prompt, gen_params)

        st.markdown("---")
        st.subheader("Результаты:")
        for text in generated:
            st.text_area(label="", value=text, height=200)


if __name__ == "__main__":
    main()