|
import io
import os
import threading

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import CLIPImageProcessor
|
|
|
# The WSGI application object; routes below register on it.
myapp = Flask(import_name=__name__)
# Enable flask_cors on the whole app with its default settings so that
# browser clients on other origins can call the API.
CORS(app=myapp)
|
|
|
|
|
# Hugging Face Hub id of the diffusion model served by this API.
repo_id = "stabilityai/stable-diffusion-2"


def _build_pipeline():
    """Load and configure the diffusion pipeline used by ``/generate``.

    The pipeline is loaded from ``repo_id`` in float32 with an explicit
    CompVis safety checker and a CLIP image processor for its inputs. The
    default scheduler is replaced with ``DPMSolverMultistepScheduler``, and
    the pipeline is moved to the CPU.
    """
    checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    loaded = DiffusionPipeline.from_pretrained(
        repo_id,
        safety_checker=checker,
        feature_extractor=processor,
        torch_dtype=torch.float32,
    )
    # Build the replacement scheduler from the shipped scheduler's config so
    # the model's noise-schedule settings carry over.
    loaded.scheduler = DPMSolverMultistepScheduler.from_config(loaded.scheduler.config)
    return loaded.to("cpu")


# Loaded once at import time and shared by every request.
pipe = _build_pipeline()
|
|
|
@myapp.route('/')
def home():
    """Liveness endpoint: confirm the API process is up and serving."""
    status_message = "Stable Diffusion API is running!"
    return status_message
|
|
|
# Every request shares the single module-level ``pipe``. Its scheduler keeps
# per-run state (timesteps, step index, cached model outputs), and Flask's
# development server serves requests on multiple threads by default.
# Serialise generation so that concurrent requests cannot corrupt each
# other's denoising loop.
_pipe_lock = threading.Lock()


def _extract_prompt():
    """Return the request's prompt, or None if it is missing or blank.

    The ``prompt`` form field is checked first. If it is absent, a JSON body
    of the form ``{"prompt": "..."}`` is accepted as well.
    """
    prompt = request.form.get('prompt')
    if not prompt:
        # silent=True returns None instead of raising for non-JSON bodies.
        payload = request.get_json(silent=True)
        if isinstance(payload, dict):
            prompt = payload.get('prompt')
    if isinstance(prompt, str) and prompt.strip():
        return prompt
    return None


@myapp.route('/generate', methods=['POST'])
def generate():
    """Generate one image from a text prompt and return it as a PNG.

    The prompt comes from the ``prompt`` form field, or from a JSON body
    ``{"prompt": "..."}``.

    Returns:
        On success, the image as a PNG attachment named
        ``generated_image.png``. If no non-blank prompt was supplied, a 400
        JSON error. If the safety checker flagged the image, a 400 JSON
        error.
    """
    prompt = _extract_prompt()
    if prompt is None:
        return jsonify({"error": "No prompt provided!"}), 400

    with _pipe_lock:
        results = pipe(prompt, guidance_scale=9, num_inference_steps=25, num_images_per_prompt=1)

    # Exactly one image was requested, so only the first verdict applies.
    if results.nsfw_content_detected[0]:
        return jsonify({"error": "NSFW content detected!"}), 400

    img_io = io.BytesIO()
    results.images[0].save(img_io, format='PNG')
    img_io.seek(0)
    # Flask 2.0 renamed ``attachment_filename`` to ``download_name``, and
    # Flask 2.2 removed the old keyword (it raised TypeError here).
    return send_file(
        img_io,
        mimetype='image/png',
        as_attachment=True,
        download_name='generated_image.png',
    )
|
|
|
if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # a browser, so it must not be enabled by default on a server bound to
    # 0.0.0.0. Opt in with FLASK_DEBUG=1; the truthiness rule mirrors Flask's
    # own handling of that variable.
    _debug_env = os.environ.get("FLASK_DEBUG", "")
    debug_mode = bool(_debug_env) and _debug_env.lower() not in {"0", "false", "no"}
    # The reloader re-executes this module in a child process, which would
    # load the model a second time. Keep it off even when debugging.
    myapp.run(host="0.0.0.0", port=8080, debug=debug_mode, use_reloader=False)