"""Flask service that turns a text prompt into a short audio clip.

Loads ``stabilityai/stable-audio-open-1.0`` once at import time and exposes
``GET /generate?prompt=...&seed=...``, which returns the generated audio as a
WAV file.
"""

import io
import os
import random

from flask import Flask, request, jsonify, send_file
# Imported before torch/diffusers, matching the original order.
import spaces
import torch
import soundfile as sf
from huggingface_hub import login
from diffusers import StableAudioPipeline

# Generation settings (same values as the original hard-coded call arguments).
MODEL_ID = "stabilityai/stable-audio-open-1.0"
NEGATIVE_PROMPT = "Low Quality"
NUM_INFERENCE_STEPS = 10  # number of diffusion steps
GUIDANCE_SCALE = 14.0
AUDIO_END_IN_S = 1  # end of the generated audio, in seconds
MAX_RANDOM_SEED = 100000  # inclusive upper bound for the default random seed

# Load the Hugging Face token from the environment (e.g. a Space secret)
# instead of hard-coding it in source.
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
if HUGGINGFACE_TOKEN is None:
    raise ValueError("Missing Hugging Face token. Please set it in Hugging Face Secrets.")
login(HUGGINGFACE_TOKEN)

# Use the GPU when available; half precision only on CUDA, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load the pipeline once at import time so every request reuses it.
pipe = StableAudioPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype)
pipe = pipe.to(device)

app = Flask(__name__)


@app.route("/generate", methods=["GET"])
@spaces.GPU  # Hugging Face `spaces` GPU decorator (ZeroGPU) around each request
def generate_audio():
    """Generate a short audio clip from a text prompt.

    Query parameters:
        prompt: Text description of the desired audio. Required; a missing,
            empty, or whitespace-only value is rejected.
        seed: Integer RNG seed for reproducible output. Defaults to a random
            integer in ``[0, MAX_RANDOM_SEED]``; a value that cannot be parsed
            as an int also falls back to that random default.

    Returns:
        An ``audio/wav`` response served inline with filename ``output.wav``,
        or a JSON ``{"error": ...}`` body with status 400 when the prompt is
        missing, or status 500 when generation/encoding fails.
    """
    prompt = request.args.get("prompt")
    seed = request.args.get("seed", random.randint(0, MAX_RANDOM_SEED), type=int)

    if not prompt or not prompt.strip():
        return jsonify({"error": "Missing prompt parameter"}), 400

    try:
        # Seeded generator on the same device as the pipeline, for reproducibility.
        generator = torch.Generator(device)
        generator.manual_seed(seed)

        audio_output = pipe(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            audio_end_in_s=AUDIO_END_IN_S,
            num_waveforms_per_prompt=1,
            generator=generator,
        ).audios

        # Take the single waveform, upcast to float32 (the pipeline runs in
        # float16 on CUDA) and move it to a NumPy array. The transpose
        # presumably turns (channels, samples) into the (frames, channels)
        # layout soundfile expects — confirm against the pipeline's output.
        waveform = audio_output[0].T.float().cpu().numpy()

        # Encode to WAV in memory. The buffer is created fresh per request,
        # so there is no leftover data to clear.
        output_io = io.BytesIO()
        sf.write(output_io, waveform, pipe.vae.sampling_rate, format="WAV")
        output_io.seek(0)  # rewind so send_file streams from the beginning

        # as_attachment=False serves the audio inline (playable in the browser);
        # download_name still provides the filename used if the user saves it.
        return send_file(
            output_io,
            as_attachment=False,
            download_name="output.wav",
            mimetype="audio/wav",
        )
    except Exception as e:
        # Keep the JSON error contract for clients, but log the full traceback
        # server-side instead of silently discarding it.
        app.logger.exception("Audio generation failed")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # 7860 is the default app port for Hugging Face Spaces.
    app.run(host="0.0.0.0", port=7860)