from flask import Flask, request, jsonify, send_file, render_template_string
from deep_translator import GoogleTranslator
import torch
from diffusers import StableDiffusionPipeline
import random
import io
import os

app = Flask(__name__)

MODEL_NAME = "Ojimi/anime-kawai-diffusion"
MODEL_DIR = "./models/anime-kawai-diffusion"  # Directory to store the model

# Download and load the model at startup
def load_model():
    # Use float16 on GPU; fall back to float32 on CPU, where most fp16 ops are unsupported
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    if not os.path.exists(MODEL_DIR):
        print(f"Downloading the model {MODEL_NAME}...")
        pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=dtype)
        pipeline.save_pretrained(MODEL_DIR)
    else:
        print(f"Loading model from {MODEL_DIR}...")
        pipeline = StableDiffusionPipeline.from_pretrained(MODEL_DIR, torch_dtype=dtype)

    if torch.cuda.is_available():
        pipeline.to("cuda")
        print("Model loaded on GPU")
    else:
        print("GPU not available. Running on CPU.")
    
    return pipeline

# Load the model once during startup
pipeline = load_model()
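
# Optional memory saver (an assumption, not part of the original app): attention
# slicing lowers peak VRAM at a small speed cost, which can help on small GPUs.
# enable_attention_slicing() is a standard StableDiffusionPipeline method; the
# ENABLE_ATTENTION_SLICING env var is a hypothetical opt-in so default behavior
# is unchanged.
if os.environ.get("ENABLE_ATTENTION_SLICING") == "1":
    pipeline.enable_attention_slicing()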

# HTML template for the index page
index_html = """
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Kawaii Diffusion</title>
</head>
<body>
    <h1>Kawaii Diffusion Image Generator</h1>
    <form action="/generate" method="get">
        <label for="prompt">Prompt:</label>
        <input type="text" id="prompt" name="prompt" required><br><br>
        <button type="submit">Generate Image</button>
    </form>
</body>
</html>
"""

@app.route('/')
def index():
    return render_template_string(index_html)

# Function to generate image locally
def generate_image_locally(prompt, steps=35, cfg_scale=7, width=512, height=512, seed=-1):
    # Translate prompt from Russian to English
    prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
    print(f'Translated prompt: {prompt}')

    # Pick a random seed if none was provided; a dedicated Generator avoids
    # mutating the global torch RNG state the way torch.manual_seed() does
    if seed == -1:
        seed = random.randint(1, 1_000_000)
    generator = torch.Generator().manual_seed(seed)

    # Generate the image using the loaded pipeline
    image = pipeline(prompt, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator).images[0]
    return image

@app.route('/generate', methods=['GET'])
def generate_image():
    try:
        prompt = request.args.get("prompt", "")
        steps = int(request.args.get("steps", 35))
        cfg_scale = float(request.args.get("cfgs", 7))
        width = int(request.args.get("width", 512))
        height = int(request.args.get("height", 512))
        seed = int(request.args.get("seed", -1))

        # Generate the image locally
        image = generate_image_locally(prompt, steps, cfg_scale, width, height, seed)

        # Save the image to a BytesIO object
        img_bytes = io.BytesIO()
        image.save(img_bytes, format='PNG')
        img_bytes.seek(0)

        return send_file(img_bytes, mimetype='image/png')
    except Exception as e:
        return jsonify({"error": str(e)}), 500

# Set the Content-Security-Policy header
@app.after_request
def add_security_headers(response):
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        "img-src 'self' data:; "
        "style-src 'self' 'unsafe-inline'; "
        "script-src 'self' 'unsafe-inline'; "
    )
    return response

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)
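
# Example request (hypothetical prompt and values; host/port match app.run above).
# generate_image_locally() translates the prompt from Russian to English, so Russian
# input is expected; the prompt below is "кот" URL-encoded. Note the cfg-scale query
# parameter is named "cfgs", as read in generate_image():
#   curl "http://localhost:7860/generate?prompt=%D0%BA%D0%BE%D1%82&steps=30&cfgs=7.5&seed=42" -o out.png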