stardate69's picture
Update app.py
c6d52c6 verified
from flask import Flask, request, jsonify, send_file
import spaces
import torch
import soundfile as sf
from huggingface_hub import login
from diffusers import StableAudioPipeline
import os
import io
import random
# Read the Hugging Face access token from the HF_TOKEN environment variable
# and authenticate before any model download is attempted.
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
# Treat an empty value like a missing one: an empty secret would otherwise
# pass an `is None` check and only fail later inside login().
if not HUGGINGFACE_TOKEN:
    raise ValueError("Missing Hugging Face token. Please set it in Hugging Face Secrets.")
login(token=HUGGINGFACE_TOKEN)
# Choose the compute device and matching precision: half precision on CUDA,
# full precision on CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the Stable Audio pipeline once at import time and move it to the
# selected device, so every request reuses the same instance.
pipe = StableAudioPipeline.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    torch_dtype=torch_dtype,
).to(device)

# Flask application object that the routes below are registered on.
app = Flask(__name__)
# Route to generate audio
@app.route("/generate", methods=["GET"])
@spaces.GPU
def generate_audio():
    """Generate an audio clip from a text prompt and return it as a WAV file.

    Query parameters:
        prompt: Text prompt passed to the pipeline (required).
        seed: Integer seed for the torch generator (optional). When it is
            absent or cannot be converted to an int, a random seed in
            [0, 100000] is used instead.

    Returns:
        On success, an ``audio/wav`` response served inline (not as an
        attachment) with download name ``output.wav``.
        On a missing or empty ``prompt``, a JSON ``{"error": ...}`` body
        with status 400.
        On any failure during generation or encoding, a JSON
        ``{"error": ...}`` body with status 500; the traceback is logged.
    """
    prompt = request.args.get("prompt")
    if not prompt:
        return jsonify({"error": "Missing prompt parameter"}), 400

    # Resolved only after the prompt is validated, so rejected requests do
    # not draw a random default.
    seed = request.args.get("seed", random.randint(0, 100000), type=int)

    try:
        # A seeded generator on the pipeline's device makes a given
        # (prompt, seed) pair use the same random stream.
        generator = torch.Generator(device)
        generator.manual_seed(seed)

        audio_output = pipe(
            prompt=prompt,
            negative_prompt="Low Quality",
            num_inference_steps=10,  # Number of diffusion steps
            guidance_scale=14.0,
            audio_end_in_s=1,
            num_waveforms_per_prompt=1,
            generator=generator,
        ).audios

        # Only one waveform is requested, so take index 0. The transpose is
        # presumably (channels, samples) -> (samples, channels) for
        # soundfile -- TODO confirm against the pipeline output shape.
        output_audio = audio_output[0].T.float().cpu().numpy()

        # A new in-memory buffer is created per request, so there is no
        # leftover data to clear before writing.
        output_io = io.BytesIO()
        sf.write(output_io, output_audio, pipe.vae.sampling_rate, format="WAV")
        output_io.seek(0)  # Rewind so send_file reads from the start

        # Serve the WAV inline; download_name only supplies the filename.
        return send_file(
            output_io,
            as_attachment=False,
            download_name="output.wav",
            mimetype="audio/wav",
        )
    except Exception as e:
        # Request boundary: keep the JSON 500 contract for clients, but
        # record the traceback so failures are diagnosable server-side.
        app.logger.exception("Audio generation failed (seed=%s)", seed)
        return jsonify({"error": str(e)}), 500
# Start Flask's built-in server only when this file is executed directly,
# listening on all interfaces at port 7860.
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")