# Hugging Face Space status: Running
import io
import os
import random
import threading

import torch
from deep_translator import GoogleTranslator
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify, send_file, render_template_string, make_response
from PIL import Image
# Hugging Face Hub repository id of the Stable Diffusion checkpoint to serve.
MODEL_NAME: str = "Ojimi/anime-kawai-diffusion"
# Local copy of the checkpoint, named after the last segment of the repo id and
# resolved relative to the current working directory.
MODEL_DIR: str = "./models/" + MODEL_NAME.rsplit("/", 1)[-1]

# The Flask application object.
app = Flask(__name__)
def load_model():
    """Load the Stable Diffusion pipeline, downloading it on first use.

    If ``MODEL_DIR`` does not exist, the checkpoint ``MODEL_NAME`` is fetched
    from the Hugging Face Hub and saved to ``MODEL_DIR``. Otherwise it is loaded
    from ``MODEL_DIR``. The pipeline is moved to the GPU when CUDA is available.

    Returns:
        The loaded ``StableDiffusionPipeline``.
    """
    use_cuda = torch.cuda.is_available()
    # Half precision is meant for the GPU; many CPU kernels have no float16
    # implementation, so fall back to float32 when running on the CPU.
    # (On CPU-only hosts the first download is therefore cached as float32.)
    dtype = torch.float16 if use_cuda else torch.float32

    if not os.path.exists(MODEL_DIR):
        print(f"Downloading the model {MODEL_NAME}...")
        pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=dtype)
        pipeline.save_pretrained(MODEL_DIR)
    else:
        print(f"Loading model from {MODEL_DIR}...")
        pipeline = StableDiffusionPipeline.from_pretrained(MODEL_DIR, torch_dtype=dtype)

    if use_cuda:
        pipeline = pipeline.to("cuda")
        print("Model loaded on GPU")
    else:
        print("GPU not available. Running on CPU.")
    return pipeline


# Load the model once, at import time, so every request reuses the same pipeline.
pipeline = load_model()
# HTML for the index page: a single form that submits the prompt to /generate via GET.
index_html = """
<!DOCTYPE html>
<html lang="ja">
<head>
<title>Kawaii Diffusion</title>
</head>
<body>
<h1>Kawaii Diffusion Image Generator</h1>
<form action="/generate" method="get">
<label for="prompt">Prompt:</label>
<input type="text" id="prompt" name="prompt" required><br><br>
<button type="submit">Generate Image</button>
</form>
</body>
</html>
"""


@app.route("/")
def index():
    """Serve the prompt form (``index_html``) at the site root."""
    return render_template_string(index_html)
# Serializes calls into the shared pipeline. Each call resets the pipeline's
# scheduler state (timesteps), and the Flask development server handles
# requests on multiple threads by default, so unsynchronized calls could
# interleave denoising loops.
_PIPELINE_LOCK = threading.Lock()


def generate_image_locally(prompt, steps=35, cfg_scale=7, width=512, height=512, seed=-1,
                           source_lang="ru"):
    """Translate *prompt* to English and render it with the loaded pipeline.

    Args:
        prompt: Text prompt written in *source_lang*.
        steps: Number of denoising steps (``num_inference_steps``).
        cfg_scale: Guidance scale (``guidance_scale``).
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed. ``-1`` picks a random seed in [1, 1_000_000]; the seed
            actually used is printed so the image can be reproduced.
        source_lang: Source language code passed to ``GoogleTranslator``
            (``"auto"`` asks the translator to detect it).

    Returns:
        The first image produced by the pipeline.
    """
    prompt = GoogleTranslator(source=source_lang, target="en").translate(prompt)
    print(f"Translated prompt: {prompt}")

    if seed == -1:
        seed = random.randint(1, 1_000_000)
    print(f"Seed: {seed}")
    # A private CPU generator yields the same random stream as torch.manual_seed()
    # but does not reseed torch's process-global RNG, which concurrent requests
    # would otherwise race on.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    with _PIPELINE_LOCK:
        result = pipeline(
            prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
        )
    return result.images[0]
# Upper bounds on user-controlled generation parameters, so a single request
# cannot monopolize or exhaust the GPU.
_MAX_STEPS = 150
_MIN_IMAGE_SIDE = 64
_MAX_IMAGE_SIDE = 1024


@app.route("/generate")
def generate_image():
    """Handle ``GET /generate``: render the requested prompt and return a PNG.

    Query parameters (defaults in parentheses): ``prompt`` (""), ``steps`` (35),
    ``cfgs`` (guidance scale, 7), ``width`` (512), ``height`` (512),
    ``seed`` (-1, meaning random).

    Returns:
        The PNG image on success. On failure, a JSON ``{"error": ...}`` body with
        status 400 for malformed or out-of-range parameters, or 500 if generation fails.
    """
    try:
        prompt = request.args.get("prompt", "")
        steps = int(request.args.get("steps", 35))
        cfg_scale = float(request.args.get("cfgs", 7))
        width = int(request.args.get("width", 512))
        height = int(request.args.get("height", 512))
        seed = int(request.args.get("seed", -1))
    except ValueError as e:
        return jsonify({"error": f"Invalid parameter: {e}"}), 400

    if not 1 <= steps <= _MAX_STEPS:
        return jsonify({"error": f"steps must be between 1 and {_MAX_STEPS}"}), 400
    for name, value in (("width", width), ("height", height)):
        # The Stable Diffusion pipeline rejects sizes that are not multiples of 8.
        if not _MIN_IMAGE_SIDE <= value <= _MAX_IMAGE_SIDE or value % 8:
            return jsonify({
                "error": f"{name} must be a multiple of 8 between "
                         f"{_MIN_IMAGE_SIDE} and {_MAX_IMAGE_SIDE}"
            }), 400

    try:
        image = generate_image_locally(prompt, steps, cfg_scale, width, height, seed)
        # Encode the image into an in-memory PNG buffer for send_file.
        img_bytes = io.BytesIO()
        image.save(img_bytes, format="PNG")
        img_bytes.seek(0)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # report the failure to the client.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
    return send_file(img_bytes, mimetype="image/png")
@app.after_request
def add_security_headers(response):
    """Set the Content-Security-Policy header on an outgoing response.

    Registered as an after-request hook so the header is applied to the app's
    responses. Any existing ``Content-Security-Policy`` value is overwritten.

    Args:
        response: The response object Flask is about to send.

    Returns:
        The same response object, with the header set.
    """
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        "img-src 'self' data:; "  # data: URIs allowed for inline images
        "style-src 'self' 'unsafe-inline'; "
        "script-src 'self' 'unsafe-inline'; "
    )
    return response
if __name__ == "__main__":
    # Listen on every interface, port 7860, when executed as a script.
    server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**server_options)