import spaces
import argparse
import os
import time
from os import path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

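# Keep all Hugging Face caches in a local "models" directory next to this script.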
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

import gradio as gr
import torch
from diffusers import FluxPipeline

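# Allow TF32 matmuls on Ampere and newer GPUs for faster inference at slightly reduced precision.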
torch.backends.cuda.matmul.allow_tf32 = True

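# Small context manager that logs the wall-clock time of a labeled block.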
class timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")

os.makedirs(cache_path, exist_ok=True)

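# Load FLUX.1-dev in bfloat16, then fuse the Hyper-SD 8-step LoRA (the 0.125 scale
# follows the ByteDance/Hyper-SD model card) so the model produces good images in ~8 steps.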
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)

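# Two-column UI: generation controls on the left, output gallery on the right.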
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # num_images = gr.Slider(label="Number of Images", minimum=1, maximum=2, step=1, value=1, interactive=True)
                height = gr.Number(label="Image Height", value=1024, interactive=True)
                width = gr.Number(label="Image Width", value=1024, interactive=True)
                steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8, interactive=True)
                scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, value=3.5, interactive=True)
                # eta = gr.Number(label="Eta (Corresponds to parameter eta (η) in the DDIM paper, i.e. 0.0 eqauls DDIM, 1.0 equals LCM)", value=1., interactive=True)
                prompt = gr.Text(label="Prompt", value="a photo of a cat", interactive=True)
                seed = gr.Number(label="Seed", value=3413, interactive=True)
                btn = gr.Button(value="run")
            with gr.Column():
                output = gr.Gallery(height=768)

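            # On Hugging Face Spaces, @spaces.GPU allocates a (Zero)GPU for the duration of each call.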
            @spaces.GPU
            def process_image(height, width, steps, scales, prompt, seed):
                with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
                    return pipe(
                        prompt=[prompt],
                        generator=torch.Generator().manual_seed(int(seed)),
                        num_inference_steps=int(steps),
                        guidance_scale=scales,
                        height=int(height),
                        width=int(width)
                    ).images

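            # Inputs forwarded to process_image, in the same order as its signature.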
            reactive_controls = [height, width, steps, scales, prompt, seed]

            # for control in reactive_controls:
            #     control.change(fn=process_image, inputs=reactive_controls, outputs=[output])

            btn.click(process_image, inputs=reactive_controls, outputs=[output])

if __name__ == "__main__":
    demo.launch()