RohitGandikota commited on
Commit
3bf9025
·
verified ·
1 Parent(s): 4cbd4f2

adding slider edits

Browse files
Files changed (1) hide show
  1. app.py +87 -13
app.py CHANGED
@@ -8,6 +8,9 @@ import torch
8
  from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
9
  from huggingface_hub import hf_hub_download
10
  from safetensors.torch import load_file
 
 
 
11
 
12
  model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
13
  repo_name = "tianweiy/DMD2"
@@ -27,6 +30,26 @@ pipe = DiffusionPipeline.from_pretrained(model_repo_id, unet=unet, torch_dtype=t
27
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
  MAX_IMAGE_SIZE = 1024
@@ -42,25 +65,49 @@ def infer(
42
  height,
43
  guidance_scale,
44
  num_inference_steps,
 
 
45
  progress=gr.Progress(track_tqdm=True),
46
  ):
47
  if randomize_seed:
48
  seed = random.randint(0, MAX_SEED)
49
 
50
- generator = torch.Generator().manual_seed(seed)
 
51
 
52
- # with network:
 
 
 
 
 
 
 
53
  image = pipe(
54
- prompt=prompt,
55
- negative_prompt=negative_prompt,
56
- guidance_scale=guidance_scale,
57
- num_inference_steps=num_inference_steps,
58
- width=width,
59
- height=height,
60
- generator=generator,
61
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- return image, seed
 
64
 
65
 
66
  examples = [
@@ -91,7 +138,32 @@ with gr.Blocks(css=css) as demo:
91
 
92
  run_button = gr.Button("Run", scale=0, variant="primary")
93
 
94
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  with gr.Accordion("Advanced Settings", open=False):
97
  negative_prompt = gr.Text(
@@ -158,8 +230,10 @@ with gr.Blocks(css=css) as demo:
158
  height,
159
  guidance_scale,
160
  num_inference_steps,
 
 
161
  ],
162
- outputs=[result, seed],
163
  )
164
 
165
  if __name__ == "__main__":
 
8
  from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
9
  from huggingface_hub import hf_hub_download
10
  from safetensors.torch import load_file
11
+ import sys
12
+ sys.path.append('.')
13
+ from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
14
 
15
  model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
16
  repo_name = "tianweiy/DMD2"
 
30
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
31
 
32
 
33
+ unet = pipe.unet
34
+
35
+ ## Change these parameters based on how you trained your sliderspace sliders
36
+ train_method = 'xattn-strict'
37
+ rank = 1
38
+ alpha =1
39
+ networks = {}
40
+ modules = DEFAULT_TARGET_REPLACE
41
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
42
+ for i in range(1):
43
+ networks[i] = LoRANetwork(
44
+ unet,
45
+ rank=int(rank),
46
+ multiplier=1.0,
47
+ alpha=int(alpha),
48
+ train_method=train_method,
49
+ fast_init=True,
50
+ ).to(device, dtype=torch_dtype)
51
+
52
+
53
 
54
  MAX_SEED = np.iinfo(np.int32).max
55
  MAX_IMAGE_SIZE = 1024
 
65
  height,
66
  guidance_scale,
67
  num_inference_steps,
68
+ sliderspace_path,
69
+ slider_scale,
70
  progress=gr.Progress(track_tqdm=True),
71
  ):
72
  if randomize_seed:
73
  seed = random.randint(0, MAX_SEED)
74
 
75
+ for net in networks:
76
+ networks[net].load_state_dict(torch.load(sliderspace_path))
77
 
78
+ for net in networks:
79
+ networks[net].set_lora_slider(slider_scale)
80
+
81
+ with networks[0]:
82
+ pass
83
+
84
+ # original image
85
+ generator = torch.Generator().manual_seed(seed)
86
  image = pipe(
87
+ prompt=prompt,
88
+ negative_prompt=negative_prompt,
89
+ guidance_scale=guidance_scale,
90
+ num_inference_steps=num_inference_steps,
91
+ width=width,
92
+ height=height,
93
+ generator=generator,
94
+ ).images[0]
95
+
96
+ # edited image
97
+ generator = torch.Generator().manual_seed(seed)
98
+ with networks[0]:
99
+ slider_image = pipe(
100
+ prompt=prompt,
101
+ negative_prompt=negative_prompt,
102
+ guidance_scale=guidance_scale,
103
+ num_inference_steps=num_inference_steps,
104
+ width=width,
105
+ height=height,
106
+ generator=generator,
107
+ ).images[0]
108
 
109
+
110
+ return image, slider_image, seed
111
 
112
 
113
  examples = [
 
138
 
139
  run_button = gr.Button("Run", scale=0, variant="primary")
140
 
141
+
142
+ # New dropdowns side by side
143
+ with gr.Row():
144
+ slider_space = gr.Dropdown(
145
+ choices=["spaceship", "car", "person"],
146
+ label="SliderSpace",
147
+ value="spaceship"
148
+ )
149
+ discovered_directions = gr.Dropdown(
150
+ choices=[f"Slider {i}" for i in range(1, 11)],
151
+ label="Discovered Directions",
152
+ value="Slider 1"
153
+ )
154
+
155
+ slider_scale = gr.Slider(
156
+ label="Slider Scale",
157
+ minimum=-2,
158
+ maximum=2,
159
+ step=0.1,
160
+ value=1,
161
+ )
162
+
163
+ with gr.Row():
164
+ result = gr.Image(label="Original Image", show_label=True)
165
+ slider_result = gr.Image(label="Discovered Edit Direction", show_label=True)
166
+
167
 
168
  with gr.Accordion("Advanced Settings", open=False):
169
  negative_prompt = gr.Text(
 
230
  height,
231
  guidance_scale,
232
  num_inference_steps,
233
+ f"sliderspace_weights/{slider_space}/{discovered_directions}",
234
+ slider_scale
235
  ],
236
+ outputs=[result, slider_result, seed],
237
  )
238
 
239
  if __name__ == "__main__":