# story-music / app.py
import gradio as gr
# from langchain.llms import OpenAI as LangChainOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
import requests
import numpy as np
import os
from openai import OpenAI
import tempfile
from scipy.io import wavfile
from pydub import AudioSegment
# Initialize OpenAI clients
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini")
# llm = LangChainOpenAI(model_name="gpt-4o-mini")
openai_client = OpenAI()
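# Both clients authenticate via the OPENAI_API_KEY environment variable:
# the ChatOpenAI model writes the story and the music prompt, while
# openai_client is used for text-to-speech narration.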
# Create a prompt template for the story
story_template = """
You are a creative story writer. Given a topic, write a short African-based story of about 100 words.
The story should incorporate themes, settings, or cultural elements related to the topic. Divide the story into 5 paragraphs. Each paragraph should be a distinct part of the story.
Topic: {topic}
Story:
"""
story_prompt = PromptTemplate.from_template(story_template)
story_chain = (RunnablePassthrough() | story_prompt | llm)
# Create a prompt template for the music based on the story prompt
music_template = """
You are a music composer. Given a topic for a story, create a music prompt that would fit the potential mood and theme of such a story.
Provide a single line description including:
1. One or more instruments
2. The mood of the music
3. The energy level
4. The genre
Combine these elements into a single, fluid sentence without labels or prefixes.
Story Topic: {topic}
Music Prompt:
"""
music_prompt = PromptTemplate.from_template(music_template)
# Compose the music chain with LCEL (the earlier LLMChain versions are kept commented out for reference)
# story_chain = LLMChain(llm=llm, prompt=story_prompt)
# music_chain = LLMChain(llm=llm, prompt=music_prompt)
music_chain = (RunnablePassthrough() | music_prompt | llm)
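# Invoking a chain with {"topic": ...} fills the prompt template, calls the chat
# model, and returns an AIMessage; callers read the generated text from .content.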

# Configuration for music generation
def get_config():
    return {
        "TYPE": os.environ.get('TYPE', 'HF'),
        "API_URL": os.environ.get('API_URL'),
        "TOKEN": os.environ.get('ACCESS_TOKEN'),
        "CUSTOM_URL": os.environ.get('CUSTOM_URL')
    }
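# TYPE selects the music backend: the default "HF" posts to API_URL with an
# ACCESS_TOKEN bearer header, while "CUSTOM" posts to CUSTOM_URL without auth.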

def get_headers(token):
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }

def query(payload, config):
    try:
        if config["TYPE"] == "CUSTOM":
            url = config["CUSTOM_URL"]
            response = requests.post(url, json=payload)
        else:
            headers = get_headers(config["TOKEN"])
            response = requests.post(config["API_URL"], headers=headers, json=payload)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        raise Exception(f"API request failed: {str(e)}")

def create_payload(prompt, duration, config):
    if config["TYPE"] == "CUSTOM":
        return {
            "prompt": prompt,
            "duration": int(duration)
        }
    else:
        return {
            "inputs": {
                "prompt": prompt,
                "duration": int(duration)
            }
        }
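# For example, a 30-second request to the default backend is shaped like
# {"inputs": {"prompt": "<music prompt>", "duration": 30}}, while the CUSTOM
# backend receives the flat {"prompt": ..., "duration": ...} form.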

def process_audio(response):
    # The endpoint is expected to return [{"generated_audio": [...], "sample_rate": ...}]
    try:
        arr = np.array(response[0]['generated_audio'])
        audio = arr.astype(np.float32)
        sample_rate = response[0]['sample_rate']
        return audio, sample_rate
    except (KeyError, IndexError, ValueError) as e:
        raise Exception(f"Error processing audio: {str(e)}")

def generate_speech(text, filename):
    try:
        response = openai_client.audio.speech.create(
            model="tts-1",
            voice="alloy",
            input=text
        )
        response.stream_to_file(filename)
        return filename
    except Exception as e:
        print(f"Error generating speech: {e}")
        return None

def save_wav(audio, sample_rate, filename):
    try:
        wavfile.write(filename, sample_rate, audio)
        return filename
    except Exception as e:
        print(f"Error saving WAV file: {e}")
        return None

def mix_audio(speech_file, music_file, output_file):
    try:
        # Load the audio files
        speech = AudioSegment.from_mp3(speech_file)
        music = AudioSegment.from_wav(music_file)
        # Lower the volume of the music
        music = music - 10  # Reduce volume by 10 dB
        # Loop the music if it's shorter than the speech
        if len(music) < len(speech):
            music = music * (len(speech) // len(music) + 1)
        # Trim the music to match the length of the speech
        music = music[:len(speech)]
        # Overlay the music onto the speech
        mixed = speech.overlay(music, position=0)
        # Export the mixed audio
        mixed.export(output_file, format="mp3")
        return output_file
    except Exception as e:
        print(f"Error mixing audio: {e}")
        return None

def generate_story_and_music(topic):
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            # Generate the story
            story = story_chain.invoke({"topic": topic}).content
            # Ensure the story has 5 paragraphs
            paragraphs = story.split('\n\n')
            if len(paragraphs) < 5:
                paragraphs.extend([''] * (5 - len(paragraphs)))
            elif len(paragraphs) > 5:
                paragraphs = paragraphs[:5]
            story = '\n\n'.join(paragraphs)
        except Exception as e:
            print(f"Error generating story: {e}")
            story = "Failed to generate story."
        try:
            # Generate the music prompt based on the topic
            final_music_prompt = music_chain.invoke({"topic": topic}).content.strip()
        except Exception as e:
            print(f"Error generating music prompt: {e}")
            final_music_prompt = "Piano and strings creating a melancholic yet hopeful mood with moderate energy in a contemporary classical style"
        # Generate music
        config = get_config()
        music_payload = create_payload(final_music_prompt, 30, config)  # 30 seconds duration
        try:
            music_response = query(music_payload, config)
            music_audio, music_sample_rate = process_audio(music_response)
            # Save music as WAV file (save_wav returns None if writing fails)
            music_file = os.path.join(temp_dir, "generated_music.wav")
            music_file = save_wav(music_audio, music_sample_rate, music_file)
        except Exception as e:
            print(f"Error generating or saving music: {e}")
            music_file = None
        # Generate speech for the entire story
        speech_file = os.path.join(temp_dir, "speech.mp3")
        generated_file = generate_speech(story, speech_file)
        if not generated_file:
            speech_file = None
        # Mix speech and music; the output is written outside the temporary
        # directory so it still exists after the directory is cleaned up
        if speech_file and music_file:
            mixed_file = "mixed_story_with_music.mp3"
            mixed_output = mix_audio(speech_file, music_file, mixed_file)
        else:
            mixed_output = None
        return story, final_music_prompt, mixed_output

# Create the Gradio interface
def gradio_interface(topic):
    story, music_prompt, mixed_audio = generate_story_and_music(topic)
    outputs = [story, music_prompt, mixed_audio]
    return outputs

iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter a topic for the story..."),
    outputs=[
        gr.Textbox(label="Generated Story"),
        gr.Textbox(label="Generated Music Prompt"),
        gr.Audio(label="Story with Background Music", type="filepath")
    ],
    title="Story Generator with Background Music",
    description="Enter a topic, and the AI will generate a short story with matching background music. The story will be narrated with the music playing in the background."
)
# Launch the Gradio app
iface.launch()
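
# On Hugging Face Spaces this script is executed directly; to run it locally,
# set the environment variables above and start it with `python app.py`.
# Note that pydub's MP3 decoding/export relies on ffmpeg being installed.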