# Hugging Face Space: AI story generator with narration and generated background music.
import gradio as gr | |
# from langchain.llms import OpenAI as LangChainOpenAI | |
from langchain.prompts import PromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain.chains import LLMChain | |
from langchain_openai import ChatOpenAI | |
import requests | |
import numpy as np | |
import os | |
from pathlib import Path | |
from openai import OpenAI | |
import tempfile | |
import shutil | |
from scipy.io import wavfile | |
from pydub import AudioSegment | |
# --- Model clients ---
# Chat model shared by the story chain and the music-prompt chain.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
# Raw OpenAI client, used only for text-to-speech narration.
openai_client = OpenAI()
# Prompt that turns a topic into a five-paragraph, ~100-word African-themed story.
story_template = """
You are a creative story writer. Given a topic, write a short African-based story of about 100 words.
The story should incorporate themes, settings, or cultural elements related to the topic. Divide the story into 5 paragraphs. Each paragraph should be a distinct part of the story.
Topic: {topic}
Story:
"""
# LCEL pipeline: pass the input dict through, format the prompt, call the model.
story_chain = RunnablePassthrough() | PromptTemplate.from_template(story_template) | llm
# Prompt that derives a one-line music description (instruments, mood, energy,
# genre) from the same story topic.
music_template = """
You are a music composer. Given a topic for a story, create a music prompt that would fit the potential mood and theme of such a story.
Provide a single line description including:
1. One or more instruments
2. The mood of the music
3. The energy level
4. The genre
Combine these elements into a single, fluid sentence without labels or prefixes.
Story Topic: {topic}
Music Prompt:
"""
# LCEL pipeline mirroring story_chain: input dict -> prompt -> chat model.
music_chain = RunnablePassthrough() | PromptTemplate.from_template(music_template) | llm
def get_config():
    """Read the music-generation backend settings from environment variables.

    Returns a dict with TYPE (defaults to "HF"), API_URL, TOKEN and CUSTOM_URL;
    unset variables come back as None.
    """
    env = os.environ
    return {
        "TYPE": env.get('TYPE', 'HF'),
        "API_URL": env.get('API_URL'),
        "TOKEN": env.get('ACCESS_TOKEN'),
        "CUSTOM_URL": env.get('CUSTOM_URL'),
    }
def get_headers(token):
    """Build the bearer-auth JSON headers for the hosted inference API."""
    return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
def query(payload, config):
    """POST *payload* to the configured music endpoint and return the JSON body.

    CUSTOM endpoints are called without auth headers; the default (HF) path
    sends a bearer token. Raises Exception on any HTTP/transport failure.
    """
    try:
        if config["TYPE"] == "CUSTOM":
            # NOTE(review): custom endpoints are assumed unauthenticated here.
            response = requests.post(config["CUSTOM_URL"], json=payload)
        else:
            response = requests.post(
                config["API_URL"],
                headers=get_headers(config["TOKEN"]),
                json=payload,
            )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as exc:
        raise Exception(f"API request failed: {str(exc)}")
def create_payload(prompt, duration, config):
    """Shape the request body for the selected backend.

    CUSTOM backends take a flat body; Hugging Face-style endpoints expect the
    same fields nested under "inputs". Duration is coerced to a whole number
    of seconds.
    """
    body = {"prompt": prompt, "duration": int(duration)}
    if config["TYPE"] == "CUSTOM":
        return body
    return {"inputs": body}
def process_audio(response):
    """Extract (float32 sample array, sample rate) from the API response.

    Expects a list whose first element holds 'generated_audio' and
    'sample_rate'. Raises Exception if the response is malformed.
    """
    try:
        first = response[0]
        samples = np.array(first['generated_audio']).astype(np.float32)
        return samples, first['sample_rate']
    except (KeyError, IndexError, ValueError) as exc:
        raise Exception(f"Error processing audio: {str(exc)}")
def generate_speech(text, filename):
    """Synthesize *text* with OpenAI TTS (tts-1, voice "alloy") into *filename*.

    Returns the filename on success, or None if the API call or the file
    write fails (the error is printed, not raised).
    """
    try:
        speech = openai_client.audio.speech.create(
            model="tts-1",
            voice="alloy",
            input=text,
        )
        speech.stream_to_file(filename)
    except Exception as e:
        print(f"Error generating speech: {e}")
        return None
    return filename
def save_wav(audio, sample_rate, filename):
    """Write *audio* samples to *filename* as WAV; return the path or None.

    Failures are printed and swallowed so the caller can fall back gracefully.
    """
    try:
        wavfile.write(filename, sample_rate, audio)
    except Exception as e:
        print(f"Error saving WAV file: {e}")
        return None
    return filename
def mix_audio(speech_file, music_file, output_file):
    """Overlay ducked background music under the narration and export MP3.

    The music is attenuated by 10 dB, looped if shorter than the speech, and
    trimmed to the speech length. Returns *output_file*, or None on failure.
    """
    try:
        speech = AudioSegment.from_mp3(speech_file)
        music = AudioSegment.from_wav(music_file) - 10  # duck music by 10 dB
        # Repeat the music until it covers the narration...
        if len(music) < len(speech):
            music = music * (len(speech) // len(music) + 1)
        # ...then cut it to exactly the narration length.
        music = music[:len(speech)]
        mixed = speech.overlay(music, position=0)
        mixed.export(output_file, format="mp3")
        return output_file
    except Exception as e:
        print(f"Error mixing audio: {e}")
        return None
def generate_story_and_music(topic):
    """Run the full pipeline for *topic*.

    Returns (story_text, music_prompt_text, mixed_mp3_path_or_None). Every
    stage degrades gracefully: a failed story yields a placeholder string, a
    failed music prompt falls back to a canned description, and the mix is
    only produced when both narration and music exist.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # 1) Story text, normalized to exactly five paragraphs.
        try:
            story = story_chain.invoke({"topic": topic}).content
            paragraphs = story.split('\n\n')[:5]
            paragraphs += [''] * (5 - len(paragraphs))
            story = '\n\n'.join(paragraphs)
        except Exception as e:
            print(f"Error generating story: {e}")
            story = "Failed to generate story."

        # 2) One-line music description, with a safe fallback.
        try:
            final_music_prompt = music_chain.invoke({"topic": topic}).content.strip()
        except Exception as e:
            print(f"Error generating music prompt: {e}")
            final_music_prompt = "Piano and strings creating a melancholic yet hopeful mood with moderate energy in a contemporary classical style"

        # 3) Background music from the external API (30-second clip).
        config = get_config()
        music_file = None
        try:
            music_response = query(create_payload(final_music_prompt, 30, config), config)
            music_audio, music_sample_rate = process_audio(music_response)
            music_file = os.path.join(temp_dir, "generated_music.wav")
            save_wav(music_audio, music_sample_rate, music_file)
        except Exception as e:
            print(f"Error generating or saving music: {e}")
            music_file = None

        # 4) Narrate the whole story to a temp MP3.
        speech_file = os.path.join(temp_dir, "speech.mp3")
        if not generate_speech(story, speech_file):
            speech_file = None

        # 5) Mix narration over music. The mix is written to the working
        # directory so it outlives the temporary directory's cleanup.
        mixed_output = None
        if speech_file and music_file:
            mixed_output = mix_audio(speech_file, music_file, "mixed_story_with_music.mp3")

        return story, final_music_prompt, mixed_output
def gradio_interface(topic):
    """Gradio callback: map a topic to [story, music prompt, mixed audio path]."""
    story, music_prompt, mixed_audio = generate_story_and_music(topic)
    return [story, music_prompt, mixed_audio]
# Assemble the UI: one topic box in, story text + music prompt + mixed audio out.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter a topic for the story..."),
    outputs=[
        gr.Textbox(label="Generated Story"),
        gr.Textbox(label="Generated Music Prompt"),
        gr.Audio(label="Story with Background Music", type="filepath"),
    ],
    title="Story Generator with Background Music",
    description="Enter a topic, and the AI will generate a short story with matching background music. The story will be narrated with the music playing in the background.",
)

# Start the app server.
iface.launch()