# Last commit e25c405 ("Update app.py", verified) by pjq626.
import streamlit as st
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from gtts import gTTS
import io
# Load image captioning model
@st.cache_resource
def load_image_captioning_model():
    """
    Fetch the BLIP image-captioning processor and model from Hugging Face.

    Both objects are created from the same checkpoint,
    ``Salesforce/blip-image-captioning-base``. The function is wrapped in
    ``st.cache_resource``, so Streamlit caches the returned pair as a shared
    resource.

    Returns:
        processor: BLIP processor that prepares images for the model.
        model: BLIP model that generates captions.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return blip_processor, blip_model
# Load text generation model
@st.cache_resource
def load_text_generation_model():
    """
    Fetch the GPT-2 language model and its tokenizer from Hugging Face.

    The function is wrapped in ``st.cache_resource``, so Streamlit caches the
    returned pair as a shared resource.

    Returns:
        model: GPT-2 language model used to continue the story prompt.
        tokenizer: GPT-2 tokenizer that converts between text and token ids.
    """
    checkpoint = "gpt2"
    # The model is loaded first, then the tokenizer (tuple items evaluate left to right).
    return (
        GPT2LMHeadModel.from_pretrained(checkpoint),
        GPT2Tokenizer.from_pretrained(checkpoint),
    )
# Generate image caption
def generate_caption(image, processor, model):
    """
    Describe ``image`` with a short caption produced by BLIP.

    Args:
        image: PIL Image object to caption.
        processor: BLIP processor that turns the image into model inputs.
        model: BLIP model used for generation.

    Returns:
        caption: The first generated sequence decoded to a string, with
        special tokens removed.
    """
    model_inputs = processor(image, return_tensors="pt")
    # The generated sequence is capped at 50 tokens.
    generated_ids = model.generate(**model_inputs, max_length=50)
    first_sequence = generated_ids[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
# Generate story from caption
def generate_story(caption, model, tokenizer):
    """
    Generate a short children's story that starts from ``caption``.

    The caption is embedded in a fixed fairy-tale prompt, and GPT-2 continues
    that prompt using sampling. The returned text includes the prompt itself.

    Args:
        caption: Input caption as a string.
        model: GPT-2 model.
        tokenizer: GPT-2 tokenizer.

    Returns:
        story: Generated story as a string. Its length is capped at 100
        *tokens* in total, prompt included. That is a token budget, not a
        word count.
    """
    # Add a child-friendly story prompt with clear instructions
    story_prompt = (
        f"Once upon a time, in a magical land, {caption}. "
        "One sunny day, something magical happened! "
        "The children were playing when they discovered something amazing. "
        "Here is what happened next: "
    )
    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, so generate() does not have to guess it.
    inputs = tokenizer(story_prompt, return_tensors="pt")
    output = model.generate(
        **inputs,
        max_length=100,  # total token budget, prompt included (tokens, not words)
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # Without do_sample=True, generation is greedy and temperature/top_k/top_p
        # below are silently ignored.
        do_sample=True,
        temperature=0.9,  # add randomness for creativity
        top_k=50,  # sample only from the 50 most likely next tokens
        top_p=0.95,  # nucleus sampling for diverse outputs
        # GPT-2 has no pad token; reuse EOS explicitly instead of relying on
        # the library's fallback (which emits a warning).
        pad_token_id=tokenizer.eos_token_id,
    )
    story = tokenizer.decode(output[0], skip_special_tokens=True)
    return story
# Convert text to speech using gTTS
def text_to_speech(text):
    """
    Synthesize English speech for ``text`` with gTTS.

    Args:
        text: Input text as a string.

    Returns:
        audio_bytes: In-memory audio as an ``io.BytesIO`` object, rewound to
        position 0 so it can be read right away.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang='en').write_to_fp(buffer)
    buffer.seek(0)  # rewind so the consumer reads from the beginning
    return buffer
# Main Streamlit app
def main():
    """
    Run the Streamlit storytelling application.

    Flow:
    1. The user uploads an image.
    2. BLIP captions it.
    3. GPT-2 turns the caption into a story.
    4. gTTS reads the story aloud.

    A file that cannot be decoded as an image is reported separately from
    failures in the later stages (model loading, generation, speech
    synthesis). This way the user is not told to re-upload a perfectly
    valid image when, e.g., the TTS request fails.
    """
    # Page title and description
    st.title("📖 Storytelling Application for Kids")
    st.markdown(
        "Welcome to the Storytelling Application!  \n"
        "Upload an image, and we'll generate a fun story for you.\n\n"
        "- 🎨 **Step 1:** Upload an image (JPG, JPEG, or PNG).\n"
        "- 🖼️ **Step 2:** Wait for the caption and story to be generated.\n"
        "- 🔊 **Step 3:** Listen to the story as audio!\n"
    )
    st.markdown("---")

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        st.info("Please upload an image to get started.")
        return

    # Decode the upload on its own so only genuine image problems get the
    # "valid image" message. Image.open is lazy; convert() forces the pixel
    # data to be read, so decoding errors surface here.
    # UnidentifiedImageError is a subclass of OSError.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"An error occurred: {e}. Please upload a valid image file.")
        return

    try:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Load models
        with st.spinner("Loading models..."):
            caption_processor, caption_model = load_image_captioning_model()
            story_model, story_tokenizer = load_text_generation_model()

        # Generate caption
        st.markdown("### 🖼️ Image Caption")
        with st.spinner("Generating caption..."):
            caption = generate_caption(image, caption_processor, caption_model)
        st.success("Caption generated successfully!")
        st.write(f"**{caption}**")

        # Generate story
        st.markdown("### 📖 Generated Story")
        with st.spinner("Generating story..."):
            story = generate_story(caption, story_model, story_tokenizer)
        st.success("Story generated successfully!")
        st.write(story)

        # Convert story to speech
        st.markdown("### 🔊 Story Audio")
        with st.spinner("Converting story to audio..."):
            audio = text_to_speech(story)
        st.success("Audio generated successfully!")
        st.audio(audio, format="audio/mp3")
    except Exception as e:  # UI boundary: surface any pipeline failure to the user
        st.error(f"An error occurred while generating the story: {e}")
# Run the app
# Entry point: call main() only when this file is executed as the top-level
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()