# Copyright 2025 Google LLC.
# Based on work by Yousif Ahmed.
# Concept: ChronoWeave – Branching Narrative Generation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: https://www.apache.org/licenses/LICENSE-2.0

import streamlit as st
import google.generativeai as genai
import os
import json
import numpy as np
from io import BytesIO
import time
import wave
import contextlib
import asyncio
import uuid  # For unique identifiers
import shutil  # For directory operations
import logging

# Image handling
from PIL import Image

# Pydantic for data validation
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
from typing import List, Optional, Dict, Any

# Video and audio processing
from moviepy.editor import ImageClip, AudioFileClip, concatenate_videoclips

# Type hints
import typing_extensions as typing

# Async support
import nest_asyncio
nest_asyncio.apply()

# Import Vertex AI SDK and Google credentials support
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel
from google.oauth2 import service_account

# Import gTTS for audio generation
from gtts import gTTS

# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# --- App Configuration ---
st.set_page_config(page_title="ChronoWeave", layout="wide", initial_sidebar_state="expanded")
st.title("🌀 ChronoWeave: Advanced Branching Narrative Generator")
st.markdown("""
Generate multiple, branching story timelines from a single theme using AI, complete with images and narration.
*Based on the work by Yousif Ahmed. Copyright 2025 Google LLC.*
""")

# --- Constants ---
TEXT_MODEL_ID = "models/gemini-1.5-flash"
AUDIO_MODEL_ID = "models/gemini-1.5-flash"
AUDIO_SAMPLING_RATE = 24000
IMAGE_MODEL_ID = "imagen-3.0-generate-002"  # Vertex AI Imagen model identifier
DEFAULT_ASPECT_RATIO = "1:1"
VIDEO_FPS = 24
VIDEO_CODEC = "libx264"
AUDIO_CODEC = "aac"
TEMP_DIR_BASE = ".chrono_temp"


def _secret_or_none(name: str) -> Optional[str]:
    """Return ``st.secrets[name]``, or None if the key or the secrets file is absent.

    Accessing ``st.secrets`` without a secrets file raises a
    ``FileNotFoundError``-family error rather than ``KeyError``; both are
    handled so that the environment-variable fallback stays reachable.
    """
    try:
        return st.secrets[name]
    except (KeyError, FileNotFoundError):
        return None


def _close_clip_quietly(clip: Any, description: str) -> None:
    """Close a MoviePy clip, logging (not raising) any error. None is ignored."""
    if clip is None:
        return
    try:
        clip.close()
    except Exception as e_close:
        logger.warning(f"⚠️ {description}: {e_close}")


# --- Secrets and Environment Variables ---
# Load GOOGLE_API_KEY
try:
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
    logger.info("Google API Key loaded from Streamlit secrets.")
except (KeyError, FileNotFoundError):
    # No such secret (or no secrets file at all): fall back to the environment.
    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
    if GOOGLE_API_KEY:
        logger.info("Google API Key loaded from environment variable.")
    else:
        st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
        st.stop()

# Load PROJECT_ID and LOCATION
PROJECT_ID = _secret_or_none("PROJECT_ID") or os.environ.get("PROJECT_ID")
LOCATION = _secret_or_none("LOCATION") or os.environ.get("LOCATION", "us-central1")
if not PROJECT_ID:
    st.error("🚨 **PROJECT_ID not set!** Please add PROJECT_ID to your secrets.", icon="🚨")
    st.stop()

# Load and verify SERVICE_ACCOUNT_JSON
service_account_json = os.environ.get("SERVICE_ACCOUNT_JSON", "").strip()
if not service_account_json:
    st.error("🚨 **SERVICE_ACCOUNT_JSON is missing or empty!** Please add your service account JSON to your secrets.", icon="🚨")
    st.stop()
try:
    service_account_info = json.loads(service_account_json)
    credentials = service_account.Credentials.from_service_account_info(service_account_info)
    logger.info("Service account credentials loaded successfully.")
except Exception as e:
    st.error(f"🚨 Failed to load service account JSON: {e}", icon="🚨")
    st.stop()

# Initialize Vertex AI with service account credentials
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=credentials)

# --- Initialize Google Clients for Text and Audio ---
try:
    genai.configure(api_key=GOOGLE_API_KEY)
    logger.info("Configured google-generativeai with API key.")
    client_standard = genai.GenerativeModel(TEXT_MODEL_ID)
    logger.info(f"Initialized text/JSON model handle: {TEXT_MODEL_ID}.")
    live_model = genai.GenerativeModel(AUDIO_MODEL_ID)
    logger.info(f"Initialized audio model handle: {AUDIO_MODEL_ID}.")
except AttributeError as ae:
    logger.exception("AttributeError during Client Init.")
    st.error(f"🚨 Init Error: {ae}. Update library?", icon="🚨")
    st.stop()
except Exception as e:
    logger.exception("Failed to initialize Google Clients/Models.")
    st.error(f"🚨 Failed Init: {e}", icon="🚨")
    st.stop()


# --- Define Pydantic Schemas ---
class StorySegment(BaseModel):
    """One scene of a timeline: image prompt, narration text and character notes."""

    scene_id: int = Field(..., ge=0)
    image_prompt: str = Field(..., min_length=10, max_length=250)
    audio_text: str = Field(..., min_length=5, max_length=150)
    character_description: str = Field(..., max_length=250)
    timeline_visual_modifier: Optional[str] = Field(None, max_length=50)

    @field_validator('image_prompt')
    @classmethod
    def image_prompt_no_humans(cls, v: str) -> str:
        """Log a warning (never reject) when the prompt mentions a human-related word."""
        if any(w in v.lower() for w in ["person", "people", "human", "man", "woman", "boy", "girl", "child"]):
            logger.warning(f"Prompt '{v[:50]}...' may contain humans.")
        return v


class Timeline(BaseModel):
    """A single story branch: its id, why it diverged, and its ordered scenes."""

    timeline_id: int = Field(..., ge=0)
    divergence_reason: str = Field(..., min_length=5)
    # `min_length` is the Pydantic v2 spelling of the deprecated `min_items`.
    segments: List[StorySegment] = Field(..., min_length=1)


class ChronoWeaveResponse(BaseModel):
    """Top-level structured response expected from the text model."""

    core_theme: str = Field(..., min_length=5)
    timelines: List[Timeline] = Field(..., min_length=1)
    total_scenes_per_timeline: int = Field(..., gt=0)

    @model_validator(mode='after')
    def check_timeline_segment_count(self) -> 'ChronoWeaveResponse':
        """Require every timeline to hold exactly `total_scenes_per_timeline` segments."""
        expected = self.total_scenes_per_timeline
        for i, t in enumerate(self.timelines):
            if len(t.segments) != expected:
                raise ValueError(f"Timeline {i} ID {t.timeline_id}: Expected {expected}, found {len(t.segments)}.")
        return self


# --- Helper Functions ---
@contextlib.contextmanager
def wave_file_writer(filename: str, channels: int = 1, rate: int = AUDIO_SAMPLING_RATE, sample_width: int = 2):
    """Context manager to safely write WAV files.

    Opens `filename` for writing, applies channel count, sample width and
    frame rate, yields the `wave` writer, and always closes it on exit.
    Errors are logged and re-raised.
    """
    wf = None
    try:
        wf = wave.open(filename, "wb")
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(rate)
        yield wf
    except Exception as e:
        logger.error(f"Error opening/configuring wave file {filename}: {e}")
        raise
    finally:
        if wf:
            try:
                wf.close()
            except Exception as e_close:
                logger.error(f"Error closing wave file {filename}: {e_close}")


# --- Audio Generation using gTTS ---
async def generate_audio_live_async(api_text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
    """
    Generates audio using gTTS (Google Text-to-Speech).
    Saves an MP3 file; MoviePy supports MP3 playback.

    Args:
        api_text: Text to be spoken.
        output_filename: Requested output path; its extension is replaced by ``.mp3``.
        voice: Accepted for interface compatibility; not used by this implementation.

    Returns:
        Path of the saved MP3 file, or None if generation failed (the error
        is logged and shown in the Streamlit UI).
    """
    task_id = os.path.basename(output_filename).split('.')[0]
    logger.info(f"🎙️ [{task_id}] Generating audio via gTTS for text: '{api_text[:60]}...'")
    try:
        tts = gTTS(text=api_text, lang="en")
        # Swap only the real extension; str.replace would also rewrite a
        # ".wav" that happens to appear elsewhere in the path.
        mp3_filename = os.path.splitext(output_filename)[0] + ".mp3"
        tts.save(mp3_filename)
        logger.info(f"✅ [{task_id}] Audio saved: {os.path.basename(mp3_filename)}")
        return mp3_filename
    except Exception as e:
        error_str = str(e)
        if "429" in error_str:
            st.error(f"Audio generation for {task_id} failed: 429 Too Many Requests from TTS API. Please try again later.", icon="🔊")
        else:
            st.error(f"Audio generation for {task_id} failed: {e}", icon="🔊")
        logger.exception(f"❌ [{task_id}] Audio generation error: {e}")
        return None


def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = "") -> Optional[ChronoWeaveResponse]:
    """Generates branching story sequences using Gemini structured output and validates with Pydantic.

    Args:
        theme: Story theme inserted into the prompt.
        num_scenes: Number of scenes requested per timeline.
        num_timelines: Number of timelines requested.
        divergence_prompt: Optional hint on how timelines should diverge.

    Returns:
        The validated response, or None when the request is blocked, the
        reply is not valid JSON, or it fails schema validation. Every failure
        is logged and reported in the Streamlit UI.
    """
    st.info(f"📚 Generating {num_timelines} timeline(s) x {num_scenes} scenes for: '{theme}'...")
    logger.info(f"Requesting story structure: Theme='{theme}', Timelines={num_timelines}, Scenes={num_scenes}")
    divergence_instruction = (
        f"Introduce clear points of divergence between timelines, after first scene if possible. "
        f"Hint: '{divergence_prompt}'. State divergence reason clearly. **For timeline_id 0, use 'Initial path' or 'Baseline scenario'.**"
    )
    prompt = f"""Act as narrative designer. Create story for theme: "{theme}".
Instructions:
1. Exactly **{num_timelines}** timelines.
2. Each timeline exactly **{num_scenes}** scenes.
3. **NO humans/humanoids**; focus on animals, fantasy creatures, animated objects, nature.
4. {divergence_instruction}.
5. Style: **'Simple, friendly kids animation, bright colors, rounded shapes'**, unless `timeline_visual_modifier` alters.
6. `audio_text`: single concise sentence (max 30 words).
7. `image_prompt`: descriptive, concise (target 15-35 words MAX). Focus on scene elements. **AVOID repeating general style**.
8. `character_description`: VERY brief (name, features). Target < 20 words.
Output: ONLY valid JSON object adhering to schema. No text before/after.
JSON Schema:
```json
{json.dumps(ChronoWeaveResponse.model_json_schema(), indent=2)}
```"""
    try:
        response = client_standard.generate_content(
            contents=prompt,
            generation_config=genai.types.GenerationConfig(
                response_mime_type="application/json",
                temperature=0.7
            )
        )
        # Step 1: the reply must be parseable JSON.
        try:
            raw_data = json.loads(response.text)
        except json.JSONDecodeError as json_err:
            logger.error(f"Failed JSON decode: {json_err}\nResponse:\n{response.text}")
            st.error(f"🚨 Failed parse story: {json_err}", icon="📄")
            st.text_area("Problem Response:", response.text, height=150)
            return None
        except Exception as e:
            logger.error(f"Error processing text: {e}")
            st.error(f"🚨 Error processing AI response: {e}", icon="📄")
            return None
        # Step 2: the JSON must satisfy the ChronoWeaveResponse schema.
        try:
            validated_data = ChronoWeaveResponse.model_validate(raw_data)
            logger.info("✅ Story structure OK!")
            st.success("✅ Story structure OK!")
            return validated_data
        except ValidationError as val_err:
            logger.error(f"JSON validation failed: {val_err}\nData:\n{json.dumps(raw_data, indent=2)}")
            st.error(f"🚨 Gen structure invalid: {val_err}", icon="🧬")
            st.json(raw_data)
            return None
    except genai.types.generation_types.BlockedPromptException as bpe:
        logger.error(f"Story gen blocked: {bpe}")
        st.error("🚨 Story prompt blocked.", icon="🚫")
        return None
    except Exception as e:
        logger.exception("Error during story gen:")
        st.error(f"🚨 Story gen error: {e}", icon="💥")
        return None


def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str = "IMG") -> Optional[Image.Image]:
    """
    Generates an image using Vertex AI's Imagen model via the Vertex AI preview API.
    Loads the pretrained Imagen model and attempts to generate an image.
    If a quota exceeded error occurs, it advises you to request a quota increase.

    Args:
        prompt: Text prompt for the image.
        aspect_ratio: Aspect ratio string forwarded to the model.
        task_id: Label used in log and UI messages.

    Returns:
        The generated PIL image, or None on failure (the error is logged and
        shown in the Streamlit UI).
    """
    logger.info(f"🖼️ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
    try:
        generation_model = ImageGenerationModel.from_pretrained(IMAGE_MODEL_ID)
        # NOTE(review): the empty-string values for negative_prompt,
        # person_generation and safety_filter_level are kept from the original
        # call; confirm against the Vertex AI SDK whether they should be
        # omitted or set to documented values.
        images = generation_model.generate_images(
            prompt=prompt,
            number_of_images=1,
            aspect_ratio=aspect_ratio,
            negative_prompt="",
            person_generation="",
            safety_filter_level="",
            add_watermark=True,
        )
        try:
            first_image = images[0]
        except IndexError:
            # Report an empty result explicitly instead of the opaque
            # "list index out of range".
            raise RuntimeError("no image was returned (the request may have been filtered)") from None
        image = first_image._pil_image
        logger.info(f"✅ [{task_id}] Image generated successfully.")
        return image
    except Exception as e:
        error_str = str(e)
        if "Quota exceeded" in error_str:
            error_msg = (
                "Quota exceeded for image generation requests. "
                "Please submit a quota increase request via the Vertex AI console: https://cloud.google.com/vertex-ai/docs/generative-ai/quotas-genai"
            )
        else:
            error_msg = f"Image generation for {task_id} failed: {e}"
        logger.exception(f"❌ [{task_id}] {error_msg}")
        st.error(error_msg, icon="🖼️")
        return None


# --- Streamlit UI Elements ---
st.sidebar.header("⚙️ Configuration")
if GOOGLE_API_KEY:
    st.sidebar.success("Google API Key Loaded", icon="✅")
else:
    st.sidebar.error("Google API Key Missing!", icon="🚨")

theme = st.sidebar.text_input("📖 Story Theme:", "A curious squirrel finds a mysterious, glowing acorn")
num_scenes = st.sidebar.slider("🎬 Scenes per Timeline:", min_value=2, max_value=7, value=3)
num_timelines = st.sidebar.slider("🌿 Number of Timelines:", min_value=1, max_value=4, value=2)
divergence_prompt = st.sidebar.text_input("↔️ Divergence Hint (Optional):", placeholder="e.g., What if a bird tried to steal it?")
st.sidebar.subheader("🎨 Visual & Audio Settings")
aspect_ratio = st.sidebar.selectbox("🖼️ Image Aspect Ratio:", ["1:1", "16:9", "9:16"], index=0)
audio_voice = None
generate_button = st.sidebar.button("✨ Generate ChronoWeave ✨", type="primary", disabled=(not GOOGLE_API_KEY), use_container_width=True)
st.sidebar.markdown("---")
st.sidebar.info("⏳ Generation can take minutes.")
st.sidebar.markdown(f"Txt:{TEXT_MODEL_ID}, Img:{IMAGE_MODEL_ID}, Aud:{AUDIO_MODEL_ID}", unsafe_allow_html=True)

# --- Main Logic ---
if generate_button:
    if not theme:
        st.error("Please enter a story theme.", icon="👈")
    else:
        run_id = str(uuid.uuid4()).split('-')[0]
        temp_dir = os.path.join(TEMP_DIR_BASE, f"run_{run_id}")
        try:
            os.makedirs(temp_dir, exist_ok=True)
            logger.info(f"Created temp dir: {temp_dir}")
        except OSError as e:
            st.error(f"🚨 Failed to create temp dir {temp_dir}: {e}", icon="📂")
            st.stop()

        final_video_paths, generation_errors = {}, {}
        chrono_response: Optional[ChronoWeaveResponse] = None
        with st.spinner("Generating narrative structure... 🤔"):
            chrono_response = generate_story_sequence_chrono(theme, num_scenes, num_timelines, divergence_prompt)

        if chrono_response:
            overall_start_time = time.time()
            all_timelines_successful = True

            with st.status("Generating assets and composing videos...", expanded=True) as status:
                for timeline_index, timeline in enumerate(chrono_response.timelines):
                    timeline_id, divergence, segments = timeline.timeline_id, timeline.divergence_reason, timeline.segments
                    timeline_label = f"Timeline {timeline_id}"
                    st.subheader(f"Processing {timeline_label}: {divergence}")
                    logger.info(f"--- Processing {timeline_label} (Idx: {timeline_index}) ---")
                    generation_errors[timeline_id] = []
                    temp_image_files, temp_audio_files, video_clips = {}, {}, []
                    # Audio clips must stay open until the timeline video has
                    # been written; they are closed together after assembly.
                    audio_clips = []
                    timeline_start_time = time.time()
                    scene_success_count = 0

                    for scene_index, segment in enumerate(segments):
                        scene_id = segment.scene_id
                        task_id = f"T{timeline_id}_S{scene_id}"
                        status.update(label=f"Processing {timeline_label}, Scene {scene_id + 1}/{len(segments)}...")
                        st.markdown(f"--- **Scene {scene_id + 1} ({task_id})** ---")
                        logger.info(f"Processing {timeline_label}, Scene {scene_id + 1}/{len(segments)}...")
                        scene_has_error = False
                        st.write(f"*Img Prompt:* {segment.image_prompt}" + (f" *(Mod: {segment.timeline_visual_modifier})*" if segment.timeline_visual_modifier else ""))
                        st.write(f"*Audio Text:* {segment.audio_text}")

                        # --- 2a. Image Generation ---
                        generated_image: Optional[Image.Image] = None
                        with st.spinner(f"[{task_id}] Generating image... 🎨"):
                            combined_prompt = segment.image_prompt
                            if segment.character_description:
                                combined_prompt += f" Featuring: {segment.character_description}"
                            if segment.timeline_visual_modifier:
                                combined_prompt += f" Style hint: {segment.timeline_visual_modifier}."
                            generated_image = generate_image_imagen(combined_prompt, aspect_ratio, task_id)

                        if generated_image:
                            image_path = os.path.join(temp_dir, f"{task_id}_image.png")
                            try:
                                generated_image.save(image_path)
                                temp_image_files[scene_id] = image_path
                                st.image(generated_image, width=180, caption=f"Scene {scene_id + 1}")
                            except Exception as e:
                                logger.error(f"❌ [{task_id}] Img save error: {e}")
                                st.error(f"Save image {task_id} failed.", icon="💾")
                                scene_has_error = True
                                generation_errors[timeline_id].append(f"S{scene_id + 1}: Img save fail.")
                        else:
                            scene_has_error = True
                            generation_errors[timeline_id].append(f"S{scene_id + 1}: Img gen fail.")
                            continue

                        # --- 2b. Audio Generation ---
                        generated_audio_path: Optional[str] = None
                        if not scene_has_error:
                            with st.spinner(f"[{task_id}] Generating audio... 🔊"):
                                audio_path_temp = os.path.join(temp_dir, f"{task_id}_audio.wav")
                                try:
                                    generated_audio_path = asyncio.run(generate_audio_live_async(segment.audio_text, audio_path_temp, audio_voice))
                                except RuntimeError as e:
                                    logger.error(f"❌ [{task_id}] Asyncio error: {e}")
                                    st.error(f"Asyncio audio error {task_id}: {e}", icon="⚡")
                                    scene_has_error = True
                                    generation_errors[timeline_id].append(f"S{scene_id + 1}: Audio async err.")
                                except Exception as e:
                                    logger.exception(f"❌ [{task_id}] Audio error: {e}")
                                    st.error(f"Audio error {task_id}: {e}", icon="💥")
                                    scene_has_error = True
                                    generation_errors[timeline_id].append(f"S{scene_id + 1}: Audio gen err.")

                            if generated_audio_path:
                                temp_audio_files[scene_id] = generated_audio_path
                                try:
                                    with open(generated_audio_path, 'rb') as ap:
                                        st.audio(ap.read(), format='audio/mp3')
                                except Exception as e:
                                    logger.warning(f"⚠️ [{task_id}] Audio preview error: {e}")
                            else:
                                scene_has_error = True
                                generation_errors[timeline_id].append(f"S{scene_id + 1}: Audio gen fail.")
                                continue

                        # --- 2c. Create Video Clip ---
                        if not scene_has_error and scene_id in temp_image_files and scene_id in temp_audio_files:
                            st.write(f"🎬 Creating clip S{scene_id + 1}...")
                            img_path, aud_path = temp_image_files[scene_id], temp_audio_files[scene_id]
                            audio_clip_instance, image_clip_instance, composite_clip = None, None, None
                            try:
                                if not os.path.exists(img_path):
                                    raise FileNotFoundError(f"Img missing: {img_path}")
                                if not os.path.exists(aud_path):
                                    raise FileNotFoundError(f"Aud missing: {aud_path}")
                                audio_clip_instance = AudioFileClip(aud_path)
                                # Load pixels eagerly and release the file handle.
                                with Image.open(img_path) as scene_image:
                                    np_image = np.array(scene_image)
                                image_clip_instance = ImageClip(np_image).set_duration(audio_clip_instance.duration)
                                composite_clip = image_clip_instance.set_audio(audio_clip_instance)
                                video_clips.append(composite_clip)
                                audio_clips.append(audio_clip_instance)
                                logger.info(f"✅ [{task_id}] Clip created (Dur: {audio_clip_instance.duration:.2f}s).")
                                st.write(f"✅ Clip created (Dur: {audio_clip_instance.duration:.2f}s).")
                                scene_success_count += 1
                            except Exception as e:
                                logger.exception(f"❌ [{task_id}] Failed clip creation: {e}")
                                st.error(f"Failed clip {task_id}: {e}", icon="🎬")
                                scene_has_error = True
                                generation_errors[timeline_id].append(f"S{scene_id + 1}: Clip fail.")
                                # Only a failed scene releases its clips right
                                # away; a successful one is still needed (its
                                # audio reader in particular) for assembly.
                                _close_clip_quietly(audio_clip_instance, f"[{task_id}] Audio clip close err")
                                _close_clip_quietly(image_clip_instance, f"[{task_id}] Image clip close err")

                    # --- 2d. Assemble Timeline Video ---
                    if video_clips and scene_success_count == len(segments):
                        status.update(label=f"Composing video {timeline_label}...")
                        st.write(f"🎞️ Assembling video {timeline_label}...")
                        logger.info(f"🎞️ Assembling video {timeline_label}...")
                        output_filename = os.path.join(temp_dir, f"timeline_{timeline_id}_final.mp4")
                        final_timeline_video = None
                        try:
                            final_timeline_video = concatenate_videoclips(video_clips, method="compose")
                            final_timeline_video.write_videofile(
                                output_filename,
                                fps=VIDEO_FPS,
                                codec=VIDEO_CODEC,
                                audio_codec=AUDIO_CODEC,
                                logger=None
                            )
                            final_video_paths[timeline_id] = output_filename
                            # Measured after the write so the figure includes assembly.
                            timeline_duration = time.time() - timeline_start_time
                            logger.info(f"✅ [{timeline_label}] Video saved: {os.path.basename(output_filename)}")
                            st.success(f"✅ Video {timeline_label} completed in {timeline_duration:.2f}s.")
                        except Exception as e:
                            logger.exception(f"❌ [{timeline_label}] Video assembly failed: {e}")
                            st.error(f"Assemble video {timeline_label} failed: {e}", icon="📼")
                            all_timelines_successful = False
                            generation_errors[timeline_id].append(f"T{timeline_id}: Assembly fail.")
                        finally:
                            _close_clip_quietly(final_timeline_video, f"[{timeline_label}] Final vid close err")
                    elif not video_clips:
                        logger.warning(f"[{timeline_label}] No clips. Skip assembly.")
                        st.warning(f"No scenes for {timeline_label}. No video.", icon="🚫")
                        all_timelines_successful = False
                    else:
                        error_count = len(generation_errors[timeline_id])
                        logger.warning(f"[{timeline_label}] {error_count} scene err(s). Skip assembly.")
                        st.warning(f"{timeline_label}: {error_count} err(s). Video not assembled.", icon="⚠️")
                        all_timelines_successful = False

                    # Release per-scene clips whether or not a video was
                    # assembled, so skipped timelines do not leak readers.
                    logger.debug(f"[{timeline_label}] Closing {len(video_clips)} clips...")
                    for i, clip in enumerate(video_clips):
                        _close_clip_quietly(clip, f"[{timeline_label}] Clip close err {i}")
                    for i, audio_clip in enumerate(audio_clips):
                        _close_clip_quietly(audio_clip, f"[{timeline_label}] Audio clip close err {i}")

                    if generation_errors[timeline_id]:
                        logger.error(f"Errors {timeline_label}: {generation_errors[timeline_id]}")

                # --- End of Timelines Loop ---
                overall_duration = time.time() - overall_start_time
                if all_timelines_successful and final_video_paths:
                    status_msg = f"Complete! ({len(final_video_paths)} videos in {overall_duration:.2f}s)"
                    status.update(label=status_msg, state="complete", expanded=False)
                    logger.info(status_msg)
                elif final_video_paths:
                    status_msg = f"Partially Complete ({len(final_video_paths)} videos, errors). {overall_duration:.2f}s"
                    # st.status accepts only "running", "complete" or "error";
                    # a partial result is reported as "error".
                    status.update(label=status_msg, state="error", expanded=True)
                    logger.warning(status_msg)
                else:
                    status_msg = f"Failed. No videos. {overall_duration:.2f}s"
                    status.update(label=status_msg, state="error", expanded=True)
                    logger.error(status_msg)

            # --- 3. Display Results ---
            st.header("🎬 Generated Timelines")
            if final_video_paths:
                sorted_timeline_ids = sorted(final_video_paths.keys())
                num_cols = min(len(sorted_timeline_ids), 3)
                cols = st.columns(num_cols)
                for idx, timeline_id in enumerate(sorted_timeline_ids):
                    col = cols[idx % num_cols]
                    video_path = final_video_paths[timeline_id]
                    timeline_data = next((t for t in chrono_response.timelines if t.timeline_id == timeline_id), None)
                    reason = timeline_data.divergence_reason if timeline_data else "Unknown"
                    with col:
                        st.subheader(f"Timeline {timeline_id}")
                        st.caption(f"Divergence: {reason}")
                        try:
                            # Bytes are read into memory, so the player and the
                            # download button keep working after cleanup.
                            with open(video_path, 'rb') as vf:
                                video_bytes = vf.read()
                            st.video(video_bytes)
                            logger.info(f"Displaying T{timeline_id}")
                            st.download_button(f"Download T{timeline_id}", video_bytes, f"timeline_{timeline_id}.mp4", "video/mp4", key=f"dl_{timeline_id}")
                            if generation_errors.get(timeline_id):
                                scene_errors = [err for err in generation_errors[timeline_id] if not err.startswith(f"T{timeline_id}:")]
                                if scene_errors:
                                    with st.expander(f"⚠️ View {len(scene_errors)} Scene Issues"):
                                        for err in scene_errors:
                                            st.warning(f"- {err}")
                        except FileNotFoundError:
                            logger.error(f"Video missing: {video_path}")
                            st.error(f"Error: Video missing T{timeline_id}.", icon="🚨")
                        except Exception as e:
                            logger.exception(f"Display error {video_path}: {e}")
                            st.error(f"Display error T{timeline_id}: {e}", icon="🚨")
            else:
                st.warning("No final videos were successfully generated.")
                st.subheader("Summary of Generation Issues")
                has_errors = any(generation_errors.values())
                if has_errors:
                    with st.expander("View All Errors", expanded=True):
                        for tid, errors in generation_errors.items():
                            if errors:
                                st.error(f"**Timeline {tid}:**")
                                for msg in errors:
                                    st.error(f" - {msg}")
                else:
                    st.info("No generation errors recorded.")

            # --- 4. Cleanup ---
            st.info(f"Attempting cleanup: {temp_dir}")
            try:
                shutil.rmtree(temp_dir)
                logger.info(f"✅ Temp dir removed: {temp_dir}")
                st.success("✅ Temp files cleaned.")
            except Exception as e:
                logger.error(f"⚠️ Failed to remove temp dir {temp_dir}: {e}")
                st.warning(f"Could not remove temp files: {temp_dir}.", icon="⚠️")
        else:
            # Story generation already reported its own error in the UI; just
            # log it and remove the (still empty) run directory.
            logger.error("Story gen/validation failed.")
            shutil.rmtree(temp_dir, ignore_errors=True)
else:
    st.info("Configure settings and click '✨ Generate ChronoWeave ✨' to start.")