import streamlit as st
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from gtts import gTTS
import io

# Load image captioning model
@st.cache_resource
def load_image_captioning_model():
    """
    Load the BLIP image-captioning processor and model from Hugging Face.

    Cached by Streamlit (`st.cache_resource`) so the checkpoint is
    downloaded and initialized only once per server process.

    Returns:
        A (processor, model) tuple: the BLIP processor for image
        preprocessing and the BLIP model for caption generation.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return blip_processor, blip_model

# Load text generation model
@st.cache_resource
def load_text_generation_model():
    """
    Load the GPT-2 language model and tokenizer from Hugging Face.

    Cached by Streamlit (`st.cache_resource`) so the weights are
    fetched only once per server process.

    Returns:
        A (model, tokenizer) tuple for text generation.
    """
    checkpoint = "gpt2"
    return (
        GPT2LMHeadModel.from_pretrained(checkpoint),
        GPT2Tokenizer.from_pretrained(checkpoint),
    )

# Generate image caption
def generate_caption(image, processor, model):
    """
    Describe *image* with a short caption produced by the BLIP model.

    Args:
        image: PIL Image object.
        processor: BLIP processor used to tensorize the image.
        model: BLIP conditional-generation model.

    Returns:
        The generated caption as a plain string.
    """
    encoded = processor(image, return_tensors="pt")
    token_ids = model.generate(**encoded, max_length=50)
    return processor.decode(token_ids[0], skip_special_tokens=True)

# Generate story from caption
def generate_story(caption, model, tokenizer):
    """
    Generate a short, child-friendly story seeded by *caption* using GPT-2.

    Args:
        caption: Input caption as a string (used to seed the prompt).
        model: GPT-2 model.
        tokenizer: GPT-2 tokenizer.

    Returns:
        story: Generated story as a string (the prompt plus roughly
        100 new tokens of continuation).
    """
    # Add a child-friendly story prompt with clear instructions
    story_prompt = (
        f"Once upon a time, in a magical land, {caption}. "
        "One sunny day, something magical happened! "
        "The children were playing when they discovered something amazing. "
        "Here is what happened next: "
    )
    input_ids = tokenizer.encode(story_prompt, return_tensors="pt")
    output = model.generate(
        input_ids,
        # max_new_tokens counts only the continuation. The original
        # max_length=100 counted tokens INCLUDING the ~45-token prompt
        # (not words), which left little room for the story itself.
        max_new_tokens=100,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # Sampling must be enabled explicitly: without do_sample=True,
        # transformers ignores temperature/top_k/top_p and decodes greedily.
        do_sample=True,
        temperature=0.9,  # Add randomness for creativity
        top_k=50,  # Limit to the 50 most likely next tokens
        top_p=0.95,  # Use nucleus sampling for diverse outputs
        # GPT-2 has no pad token; reuse EOS to avoid the runtime warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    story = tokenizer.decode(output[0], skip_special_tokens=True)
    return story

# Convert text to speech using gTTS
def text_to_speech(text):
    """
    Render *text* as spoken English audio via Google Text-to-Speech.

    Args:
        text: The text to vocalize.

    Returns:
        An in-memory MP3 file as a BytesIO object, rewound to the start
        so callers can read it immediately.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang='en').write_to_fp(buffer)
    buffer.seek(0)  # Rewind: write_to_fp leaves the cursor at the end.
    return buffer

# Main Streamlit app
def main():
    """
    Run the Streamlit storytelling application.

    Flow: upload an image -> caption it with BLIP -> expand the caption
    into a short story with GPT-2 -> play the story aloud via gTTS.
    """
    # Page title and description
    st.title("📖 Storytelling Application for Kids")
    st.markdown("""
        Welcome to the Storytelling Application!  
        Upload an image, and we'll generate a fun story for you.  
        🎨 **Step 1:** Upload an image (JPG, JPEG, or PNG).  
        🖼️ **Step 2:** Wait for the caption and story to be generated.  
        🔊 **Step 3:** Listen to the story as audio!
    """)
    st.markdown("---")

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        st.info("Please upload an image to get started.")
        return

    # Decode the upload in its own try-block so a corrupt/invalid file
    # gets a targeted message. The original wrapped the whole pipeline
    # in one handler and blamed the uploaded file for every failure,
    # including model-download and generation errors.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Could not read the image ({e}). Please upload a valid JPG, JPEG, or PNG file.")
        return

    st.image(image, caption="Uploaded Image", use_container_width=True)

    try:
        # Load models (cached after the first run via st.cache_resource)
        with st.spinner("Loading models..."):
            caption_processor, caption_model = load_image_captioning_model()
            story_model, story_tokenizer = load_text_generation_model()

        # Generate caption
        st.markdown("### 🖼️ Image Caption")
        with st.spinner("Generating caption..."):
            caption = generate_caption(image, caption_processor, caption_model)
            st.success("Caption generated successfully!")
            st.write(f"**{caption}**")

        # Generate story
        st.markdown("### 📖 Generated Story")
        with st.spinner("Generating story..."):
            story = generate_story(caption, story_model, story_tokenizer)
            st.success("Story generated successfully!")
            st.write(story)

        # Convert story to speech
        st.markdown("### 🔊 Story Audio")
        with st.spinner("Converting story to audio..."):
            audio = text_to_speech(story)
            st.success("Audio generated successfully!")
            st.audio(audio, format="audio/mp3")

    except Exception as e:
        # Pipeline failures (model download, generation, TTS network)
        # are not the user's fault — report them without blaming the file.
        st.error(f"An error occurred while generating the story: {e}")

# Script entry point: launch the Streamlit app when run directly
# (e.g. `streamlit run app.py`), not when imported as a module.
if __name__ == "__main__":
    main()