File size: 8,610 Bytes
2b8b77d
 
1eb5467
e1fcf74
2b8b77d
 
 
 
e2371c5
16908f1
3c2c999
003a054
 
2d475e1
 
c6dfa2b
2d475e1
e1fcf74
1314d69
f54a55a
 
2d475e1
 
292c38f
e1fcf74
140db10
 
3d9ac9f
 
 
 
 
 
7d4e0fa
3d9ac9f
 
7b28dab
b5d5466
 
f8e56c2
 
 
 
 
 
 
 
 
db1abac
 
b777a65
 
f8e56c2
140db10
f217e4d
9a397ea
 
140db10
 
 
718ba97
00fc70b
f217e4d
1eb5467
0cbf06a
ca9e441
 
0cbf06a
3d9ac9f
c464ec4
e1ad51f
 
9d731d3
0cbf06a
 
 
 
 
 
f40fb7c
cd2465c
5e49d53
 
 
 
81ffcd6
dc2976a
60fee97
2d475e1
f217e4d
3d9ac9f
7c696fc
 
3d9ac9f
 
 
 
 
 
 
718ba97
 
1314d69
9d731d3
29017ec
 
9d731d3
10c555d
29017ec
 
 
 
 
 
 
 
 
 
 
 
 
b777a65
3664c52
e32f66c
b777a65
10c555d
9d731d3
b777a65
0cbf06a
29017ec
f217e4d
 
 
3d9ac9f
d5a8945
7b9e6e4
718ba97
 
f217e4d
ccc38b8
 
145506a
140db10
 
 
 
db1abac
e514550
ccc38b8
 
a9bef22
7d4e0fa
3d9ac9f
5e49d53
7d4e0fa
132c798
 
7d4e0fa
ccc38b8
10c555d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4e0fa
e514550
 
7c696fc
 
 
 
 
10c555d
e514550
7c696fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164edec
cd2465c
50d6862
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import gradio as gr
import spaces
from clip_slider_pipeline import CLIPSliderFlux
from diffusers import FluxPipeline, AutoencoderTiny
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers.utils import load_image
from diffusers.utils import export_to_gif
import random

# load pipelines
# Hub checkpoint used as the base FLUX text-to-image pipeline.
base_model = "black-forest-labs/FLUX.1-schnell"

# Tiny autoencoder (madebyollin/taef1) in bfloat16, moved to CUDA, and passed
# to the pipeline below as its `vae` component.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxPipeline.from_pretrained(base_model,
                                    vae=taef1,
                                    torch_dtype=torch.bfloat16)
# NOTE(review): `pipe` itself is not moved to CUDA here; presumably
# CLIPSliderFlux handles device placement through its `device` argument — confirm.

# Convert the transformer to the channels_last memory format.
pipe.transformer.to(memory_format=torch.channels_last)
# Disabled optional tweaks (compilation / CPU offload), kept for reference:
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
# pipe.enable_model_cpu_offload()
# Wrapper used by `generate` to find CLIP latent directions and to render
# images along them.
clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))

# Inclusive upper bound for seeds (random.randint range and the seed slider maximum).
MAX_SEED = 2**32-1

def convert_to_centered_scale(num):
    """Return ``num`` consecutive integer slider positions centred on 0.

    Odd counts are symmetric (5 -> -2..2); even counts extend one step further
    on the positive side (4 -> -1..2). Non-positive counts yield ``()``.
    """
    half = num // 2
    # An even count cannot be symmetric around 0, so start one step higher.
    first = -half + (1 if num % 2 == 0 else 0)
    return tuple(range(first, half + 1))

@spaces.GPU(duration=200)
def generate(prompt,
             concept_1,
             concept_2,
             scale,
             randomize_seed=True,
             seed=42,
             recalc_directions=True,
             iterations=200,
             steps=3,
             interm_steps=21,
             guidance_scale=3.5,
             x_concept_1="", x_concept_2="",
             avg_diff_x=None,
             total_images=None,
             progress=gr.Progress(track_tqdm=True)
    ):
    """Render ``interm_steps`` images along the CLIP direction between two concepts.

    The direction comes from ``clip_slider.find_latent_direction``. It is cached
    by the caller: ``avg_diff_x`` holds the direction and ``x_concept_1`` /
    ``x_concept_2`` hold the concept pair it was built for. The cached direction
    is reused only when the pair is unchanged (in the same order), a cached
    direction exists, and ``recalc_directions`` is false.

    Args:
        prompt: Text prompt used for every image.
        concept_1, concept_2: The two concepts to steer between.
        scale: Maximum strength; images span ``-scale`` .. ``scale`` in equal steps.
        randomize_seed: When true, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        seed: Seed passed to every image generation.
        recalc_directions: Force recomputing the direction.
        iterations: ``num_iterations`` passed to ``find_latent_direction``.
        steps: Inference steps per image.
        interm_steps: Number of images to generate (must be >= 1).
        guidance_scale: Guidance scale passed to ``clip_slider.generate``.
        x_concept_1, x_concept_2: Concept pair the cached direction was built for.
        avg_diff_x: Cached direction returned by a previous call (CPU copy), or None.
        total_images: Unused; accepted so the UI input list keeps its shape.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        A 9-tuple matching the UI outputs:
        - the concept pair used;
        - the direction (CPU copy);
        - the ``export_to_gif`` result;
        - a strip of 256x256 thumbnails;
        - the image list;
        - the image at slider position 0;
        - a slider update;
        - the seed used.

    Raises:
        ValueError: If ``interm_steps`` is less than 1.
    """
    slider_x = [concept_2, concept_1]
    # check if avg diff for directions need to be re-calculated
    print("slider_x", slider_x)
    print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    interm_steps = int(interm_steps)
    if interm_steps < 1:
        raise ValueError(f"interm_steps must be at least 1, got {interm_steps}")

    # Compare the *ordered* pair rather than sorted lists. The direction is
    # built from (slider_x[0], slider_x[1]) in that order, and the slider label
    # follows the current order. A direction cached for the swapped pair would
    # therefore not match the label (presumably it points the opposite way).
    needs_direction = (
        recalc_directions
        or avg_diff_x is None
        or slider_x != [x_concept_1, x_concept_2]
    )
    if needs_direction:
        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
        x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
    else:
        # Reuse the cached direction. It was handed back as a CPU copy (see
        # `avg_diff.cpu()` below), so move it to the CUDA device clip_slider uses.
        avg_diff = avg_diff_x.to("cuda")

    images = []
    high_scale = scale
    low_scale = -1 * scale
    for i in range(interm_steps):
        if interm_steps == 1:
            # A single image sits at slider position 0, i.e. the neutral midpoint.
            cur_scale = 0.0
        else:
            cur_scale = low_scale + (high_scale - low_scale) * i / (interm_steps - 1)
        image = clip_slider.generate(prompt,
                                     width=768,
                                     height=768,
                                     guidance_scale=guidance_scale,
                                     scale=cur_scale, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
        images.append(image)

    # Horizontal strip with one 256x256 thumbnail per generated image.
    thumb = 256
    canvas = Image.new('RGB', (thumb * interm_steps, thumb))
    for i, im in enumerate(images):
        canvas.paste(im.resize((thumb, thumb)), (thumb * i, 0))

    # slider_x is stored reversed, so the label reads "concept_1, concept_2".
    comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"

    # One slider position per image, centred on 0; position 0 selects the middle image.
    scale_total = convert_to_centered_scale(interm_steps)
    scale_min = scale_total[0]
    scale_max = scale_total[-1]
    scale_middle = scale_total.index(0)
    post_generation_slider_update = gr.update(label=comma_concepts_x, value=0, minimum=scale_min, maximum=scale_max, interactive=True)
    avg_diff_x = avg_diff.cpu()
    return x_concept_1, x_concept_2, avg_diff_x, export_to_gif(images, "clip.gif", fps=5), canvas, images, images[scale_middle], post_generation_slider_update, seed

def update_pre_generated_images(slider_value, total_images):
    """Return the pre-generated image that matches the post-generation slider position.

    Args:
        slider_value: Slider position. Positions map onto
            ``convert_to_centered_scale(len(total_images))``; None is treated as 0.
        total_images: Images from the last ``generate`` run, ordered from the
            lowest to the highest scale.

    Returns:
        The selected image, or None when nothing has been generated yet.
    """
    if not total_images:
        return None
    positions = convert_to_centered_scale(len(total_images))
    if slider_value is None:
        slider_value = 0
    # Snap to the nearest integer step and clamp into the valid range. A stale
    # or out-of-range slider value would otherwise make tuple.index raise ValueError.
    position = min(max(int(round(slider_value)), positions[0]), positions[-1])
    return total_images[positions.index(position)]
    
def reset_recalc_directions():
    """Flag that the cached direction is stale.

    Used as an input-less event handler whose single output is the
    ``recalc_directions`` state; ``generate`` recomputes the direction when it is true.
    """
    must_recompute = True
    return must_recompute


# Header HTML rendered at the top of the demo via gr.HTML.
# The trailing emoji was mis-encoded as "πŸͺ" (UTF-8 bytes F0 9F AA 90 decoded
# with a Greek code page); it is restored to the intended U+1FA90 ringed planet.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
    <img src="https://huggingface.co/spaces/LatentNavigation/latentnavigation-flux/resolve/main/Group 4-16.png" width="120" style="display: inline-block">
    <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block;font-size:1.75em">Latent Navigation</h1>
</div>
<div style="display: flex;align-items: center;justify-content: center">
    <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Exploring CLIP text space with FLUX.1 schnell 🪐</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    <a href="https://github.com/linoytsaban/semantic-sliders" target="_blank">code</a>
     | 
    <a href="https://huggingface.co/spaces/LatentNavigation/latentnavigation-flux?duplicate=true" target="_blank" style="
        display: inline-block;
    ">
    <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
# Styling for the strip and GIF previews (targeted through their elem_id values).
css = (
    "\n"
    "#strip, #gif{max-height: 170px; min-height: 65px}\n"
    "#strip img{object-fit: cover}\n"
)
# Example rows, in gr.Examples input order: [prompt, concept_1, concept_2, strength].
examples = [
    ["a dog in the park", "winter", "summer", 1.25],
    ["a house", "USA suburb", "Europe", 2],
    ["a tomato", "rotten", "super fresh", 2],
]

with gr.Blocks(css=css) as demo:

    gr.HTML(intro)

    # Per-session state carried between `generate` calls:
    # - the concept pair and the CPU copy of the direction from the last run;
    # - the generated images that the post-generation slider indexes into.
    x_concept_1 = gr.State("")
    x_concept_2 = gr.State("")
    total_images = gr.State([])

    avg_diff_x = gr.State()

    # Set to True by reset_recalc_directions (listeners below); passed to
    # `generate` to force recomputing the direction.
    recalc_directions = gr.State(False)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                concept_1 = gr.Textbox(label="1st direction to steer", placeholder="winter")
                concept_2 = gr.Textbox(label="2nd direction to steer", placeholder="summer")
            prompt = gr.Textbox(label="Prompt", info="Describe what you want to be steered by the directions", placeholder="A dog in the park")
            x = gr.Slider(minimum=0, value=1.5, step=0.1, maximum=4.0, label="Strength", info="maximum strength on each direction (unstable beyond 2.5)")
            submit = gr.Button("Generate directions")

        with gr.Column():
            with gr.Group(elem_id="group"):
                post_generation_image = gr.Image(label="Generated Images", type="filepath")
                # Range and label are replaced by `generate` after each run.
                post_generation_slider = gr.Slider(minimum=-10, maximum=10, value=0, step=1)
            with gr.Row():
                with gr.Column(scale=4, min_width=50):
                    image_seq = gr.Image(label="Strip", elem_id="strip", height=65)

                with gr.Column(scale=2, min_width=50):
                    output_image = gr.Image(label="Gif", elem_id="gif")

    with gr.Accordion(label="Advanced options", open=False):
        # Odd counts only (step=2 from 3), so the slider positions are symmetric around 0.
        interm_steps = gr.Slider(label = "Num of intermediate images", minimum=3, value=21, maximum=65, step=2)
        with gr.Row():
            iterations = gr.Slider(label = "Num iterations for clip directions", minimum=0, value=200, maximum=500, step=1)
            steps = gr.Slider(label = "Num inference steps", minimum=1, value=3, maximum=8, step=1)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.1,
                maximum=10.0,
                step=0.1,
                value=3.5,
            )
            with gr.Column():
                randomize_seed = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, label="Seed", interactive=True, randomize=True)

    # Examples supply only the first four inputs; `generate` uses its defaults for the rest.
    examples_gradio = gr.Examples(
        examples=examples,
        inputs=[prompt, concept_1, concept_2, x],
        fn=generate,
        outputs=[x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq, total_images, post_generation_image, post_generation_slider, seed],
        cache_examples="lazy"
    )

    submit.click(
        fn=generate,
        inputs=[prompt, concept_1, concept_2, x, randomize_seed, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale, x_concept_1, x_concept_2, avg_diff_x, total_images],
        outputs=[x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq, total_images, post_generation_image, post_generation_slider, seed]
    )
    iterations.change(
        fn=reset_recalc_directions,
        outputs=[recalc_directions]
    )
    # NOTE(review): `generate` writes the seed it used back to this slider.
    # With "Randomize seed" on, that presumably fires this listener and sets
    # recalc_directions to True, and nothing ever sets it back to False. The
    # cached direction is then likely never reused after the first run —
    # confirm this is intended.
    seed.change(
        fn=reset_recalc_directions,
        outputs=[recalc_directions]
    )
    # Scrubbing the slider swaps in an already generated image (no GPU work).
    post_generation_slider.change(
        fn=update_pre_generated_images,
        inputs=[post_generation_slider, total_images],
        outputs=[post_generation_image],
        queue=False,
        show_progress="hidden",
        concurrency_limit=None
    )

if __name__ == "__main__":
    demo.launch()