import html

import nltk
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from nltk.corpus import wordnet
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Make sure the WordNet corpus used by get_synonyms() is available.
# Streamlit re-executes this script on every widget interaction, so only hit
# the downloader when the corpus cannot be found locally.
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')


@st.cache_resource
def _load_text_generator(text_model="gpt2"):
    """Build the text-generation pipeline for `text_model` once and cache it.

    Streamlit re-runs the script on every interaction and generate_text() is
    called several times per article; caching avoids reloading the tokenizer
    and model weights on each call.
    """
    tokenizer = AutoTokenizer.from_pretrained(text_model)
    model = AutoModelForCausalLM.from_pretrained(text_model)
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,  # first GPU if present, else CPU
    )


def generate_text(prompt, temperature=0.7, top_k=50, repetition_penalty=1.2, max_length=None, min_length=10):
    """Sample one continuation of `prompt` with GPT-2.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature.
        top_k: Sample only from the `top_k` most likely tokens.
        repetition_penalty: Penalty applied to already-generated tokens.
        max_length: Maximum total length in tokens; None leaves it to the
            model's generation config.
        min_length: Minimum total length in tokens.

    Returns:
        The ``generated_text`` of the single returned sequence. With the
        pipeline's defaults this starts with `prompt` itself; callers strip it.
    """
    generator = _load_text_generator()
    eos_id = generator.tokenizer.eos_token_id
    return generator(
        prompt,
        max_length=max_length,
        min_length=min_length,
        do_sample=True,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_return_sequences=1,
        eos_token_id=eos_id,
        # Pass EOS as the pad id explicitly so generate() does not have to fall
        # back to it (with a warning) when the tokenizer defines no pad token.
        pad_token_id=eos_id,
    )[0]["generated_text"]

@st.cache_resource
def _load_image_pipeline(image_model="runwayml/stable-diffusion-v1-5"):
    """Load the Stable Diffusion pipeline once, move it to the best device, cache it.

    generate_image() is called several times per article; without caching the
    full pipeline would be reloaded from disk on every call.
    """
    # NOTE(review): confirm this Hub repo id still resolves; hosted model ids
    # can be moved or withdrawn upstream.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(image_model, torch_dtype=torch.float32)
    # NOTE(review): the cached pipeline object is shared by every session;
    # concurrent generate_image() calls are not serialized here.
    return pipe.to(device)


def generate_image(prompt):
    """Generate a single image for `prompt` with Stable Diffusion.

    Returns the first element of the pipeline's ``images`` output (callers
    call ``.resize`` on it, so presumably a PIL image — confirm if relied upon).
    """
    pipe = _load_image_pipeline()
    return pipe(prompt).images[0]

def get_synonyms(word):
    """Return the WordNet lemma names for `word`, without duplicates.

    Names are listed in the order WordNet returns them (synset by synset), with
    underscores in multi-word lemmas turned into spaces. A blank `word` yields
    an empty list.
    """
    # WordNet spells multi-word lemmas with underscores (hence the reverse
    # replacement below), so normalize the query the same way.
    query = word.strip().replace(' ', '_')
    if not query:
        return []
    # dict.fromkeys de-duplicates while keeping first-seen order; a set's
    # iteration order depends on string hashing, which made the keyword order
    # (and therefore the generation prompt) change between interpreter runs.
    names = dict.fromkeys(
        lemma.name().replace('_', ' ')
        for synset in wordnet.synsets(query)
        for lemma in synset.lemmas()
    )
    return list(names)

def _section_text(label, subject, min_length, max_length):
    """Generate a passage for the prompt ``"<label> : <subject>"``.

    generate_text() returns the prompt followed by the continuation, so the
    prompt is removed and surrounding whitespace stripped before returning.
    """
    prompt = f"{label} : {subject}"
    text = generate_text(prompt, min_length=min_length, max_length=max_length)
    return text.replace(prompt, "").strip()


def _render_article(title, keywords):
    """Render the generated article: heading, four images and four text sections."""
    subject = " ".join(keywords)

    # str.capitalize() would lower-case the rest of the title ("AI" -> "Ai"),
    # so only upper-case the first character. The title is user input placed
    # in raw HTML, so escape it.
    heading = html.escape(title[:1].upper() + title[1:])
    st.markdown(
        f"<h1 style='text-align: center; color: blue; font-size: 70px;'>{heading}</h1>",
        unsafe_allow_html=True
    )

    _, centre_col, _ = st.columns([1, 2, 1])
    with centre_col:
        hero_image = generate_image(subject).resize((700, 500))
        st.image(hero_image, use_column_width=True)

    # Introduction
    st.subheader("Introduction")
    st.write(_section_text("introduction about", subject, 100, 200))

    # Keep a space between the keywords and the style word in image prompts.
    banner_image = generate_image(f"{subject} bright").resize((700, 300))
    st.image(banner_image, use_column_width=True)

    # Body 1
    text_col, image_col = st.columns(2)
    with text_col:
        st.subheader("Body")
        st.write(_section_text("article about", subject, 100, 150))

    with image_col:
        side_image = generate_image(f"{subject} shade")
        st.markdown("<br><br><br><br>", unsafe_allow_html=True)
        st.image(side_image, use_column_width=True)

    # Body 2 (strip the whole prompt, like every other section)
    st.write(_section_text("article about", subject, 200, 300))

    closing_image = generate_image(f"{subject} {title}").resize((700, 300))
    st.image(closing_image, use_column_width=True)

    # Conclusion
    st.subheader("Conclusion")
    st.write(_section_text("conclusion about", subject, 100, 200))


# ":black[...]" is not one of Streamlit's supported markdown text colors, so
# use the theme's default text color.
st.title("_AI-Generated Blog Post_")

title = st.text_input("Topic of the Article")


keywords_selection = st.selectbox('Do you want to select Keywords Manually or Automatic',['','Manually','Automatic'])

# Always defined, so the button handler cannot hit a NameError when no
# keyword mode has been chosen.
keywords = []

if keywords_selection == 'Manually':
    keywords_input = st.text_input("Enter Some Keywords About The Topic (Separate keywords with commas)")
    # Skip blank entries (empty box, trailing or doubled commas).
    keywords = [word.strip() for word in keywords_input.split(',') if word.strip()]
    if title.strip():
        keywords.append(title.strip())

elif keywords_selection == 'Automatic':
    topic = title.strip()
    # Nothing to look up without a topic; if WordNet has no entry for it,
    # fall back to the topic itself.
    keywords = (get_synonyms(topic) or [topic]) if topic else []
    st.write(f'Your keywords Are {keywords}')

if st.button('Generate Article'):
    if not title.strip():
        st.warning('Please Enter Title and Keywords')
    elif not keywords:
        st.warning("Please input keywords to generate content.")
    else:
        try:
            _render_article(title.strip(), keywords)
        except Exception as err:
            # Top-level UI boundary: surface model/runtime failures instead of
            # masking them as an input-validation message.
            st.error("Article generation failed.")
            st.exception(err)


# Numbered usage steps shown in the sidebar, rendered one per line.
_INSTRUCTION_STEPS = (
    "1. Enter title and keywords related to the topic you want to generate content about.",
    "2. Click 'Generate Article' to create the AI-generated blog post.",
    "3. Explore the Introduction, Body, and Conclusion sections of the generated content.",
)

st.sidebar.title("Instructions")
st.sidebar.write("\n".join(_INSTRUCTION_STEPS))