Ashoka74 committed on
Commit f0943f5 · verified · 1 Parent(s): e2fdc43

Create app.py

Files changed (1)
  1. app.py +1103 -0
app.py ADDED
@@ -0,0 +1,1103 @@
1
+ import os
2
+ import math
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import safetensors.torch as sf
7
+ import db_examples
8
+ import datetime
9
+ from pathlib import Path
10
+ from io import BytesIO
11
+
12
+ from PIL import Image
13
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
14
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
15
+ from diffusers.models.attention_processor import AttnProcessor2_0
16
+ from transformers import CLIPTextModel, CLIPTokenizer
17
+ from briarmbg import BriaRMBG
18
+ from enum import Enum
19
+ from torch.hub import download_url_to_file
20
+
22
+ import cv2
23
+
24
+ from typing import Optional
25
+
26
+ from Depth.depth_anything_v2.dpt import DepthAnythingV2
27
+
28
+
29
+
30
+ # from FLORENCE
31
+ import spaces
32
+ import supervision as sv
33
+ import torch
34
+ from PIL import Image
35
+
36
+ from utils.sam import load_sam_image_model, run_sam_inference
37
+
38
+
39
+ try:
40
+ import xformers
41
+ import xformers.ops
42
+ XFORMERS_AVAILABLE = True
43
+ print("xformers is available - Using memory efficient attention")
44
+ except ImportError:
45
+ XFORMERS_AVAILABLE = False
46
+ print("xformers not available - Using default attention")
47
+
48
+ # Memory optimizations for RTX 2070
49
+ torch.backends.cudnn.benchmark = True
50
+ if torch.cuda.is_available():
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+ # Set a smaller attention slice size for RTX 2070
54
+ torch.backends.cuda.max_split_size_mb = 512
55
+ device = torch.device('cuda')
56
+ else:
57
+ device = torch.device('cpu')
58
+
59
+ # 'stablediffusionapi/realistic-vision-v51'
60
+ # 'runwayml/stable-diffusion-v1-5'
61
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
62
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
63
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
64
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
65
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
66
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
67
+
68
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
69
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
70
+ model = model.to(device)
71
+ model.eval()
72
+
73
+ # Change UNet
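+ # The block below widens conv_in from 4 to 8 input channels so the VAE-encoded
+ # foreground condition can be concatenated with the noisy latent; the new channel
+ # weights start at zero, so the patched UNet initially behaves like the original.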
74
+
75
+ with torch.no_grad():
76
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
77
+ new_conv_in.weight.zero_()
78
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
79
+ new_conv_in.bias = unet.conv_in.bias
80
+ unet.conv_in = new_conv_in
81
+
82
+
83
+ unet_original_forward = unet.forward
84
+
85
+
86
+ def enable_efficient_attention():
87
+ if XFORMERS_AVAILABLE:
88
+ try:
89
+ # RTX 2070 specific settings
90
+ unet.set_use_memory_efficient_attention_xformers(True)
91
+ vae.set_use_memory_efficient_attention_xformers(True)
92
+ print("Enabled xformers memory efficient attention")
93
+ except Exception as e:
+ print(f"Xformers error: {e}")
+ print("Falling back to sliced attention")
+ # Sliced-attention fallback; diffusers exposes this as set_attention_slice on the UNet
+ unet.set_attention_slice(4)
+ unet.set_attn_processor(AttnProcessor2_0())
+ vae.set_attn_processor(AttnProcessor2_0())
+ else:
+ # Fallback for when xformers is not available
+ print("Using sliced attention")
+ unet.set_attention_slice(4)
+ unet.set_attn_processor(AttnProcessor2_0())
+ vae.set_attn_processor(AttnProcessor2_0())
108
+
109
+ # Add memory clearing function
110
+ def clear_memory():
111
+ if torch.cuda.is_available():
112
+ torch.cuda.empty_cache()
113
+ torch.cuda.synchronize()
114
+
115
+ # Enable efficient attention
116
+ enable_efficient_attention()
117
+
118
+
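+ # hooked_unet_forward pulls the 'concat_conds' tensor (foreground latents) out of
+ # cross_attention_kwargs, repeats it to match the batch size, and concatenates it
+ # with the noisy sample along the channel dimension to feed the 8-channel conv_in.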
119
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
120
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
121
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
122
+ new_sample = torch.cat([sample, c_concat], dim=1)
123
+ kwargs['cross_attention_kwargs'] = {}
124
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
125
+
126
+
127
+ unet.forward = hooked_unet_forward
128
+
129
+ # Load
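+ # The IC-Light fc checkpoint stores weight offsets relative to the base SD 1.5 UNet;
+ # they are added element-wise to the current state dict below rather than loaded directly.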
130
+
131
+ model_path = './models/iclight_sd15_fc.safetensors'
132
+ # model_path = './models/iclight_sd15_fbc.safetensors'
133
+
134
+
135
+ # if not os.path.exists(model_path):
136
+ # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
137
+
138
+ sd_offset = sf.load_file(model_path)
139
+ sd_origin = unet.state_dict()
140
+ keys = sd_origin.keys()
141
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
142
+ unet.load_state_dict(sd_merged, strict=True)
143
+ del sd_offset, sd_origin, sd_merged, keys
144
+
145
+ # Device
146
+
147
+ # device = torch.device('cuda')
148
+ # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
149
+ # vae = vae.to(device=device, dtype=torch.bfloat16)
150
+ # unet = unet.to(device=device, dtype=torch.float16)
151
+ # rmbg = rmbg.to(device=device, dtype=torch.float32)
152
+
153
+
154
+ # Device and dtype setup
155
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
156
+ dtype = torch.float16 # RTX 2070 works well with float16
157
+
158
+ # Memory optimizations for RTX 2070
159
+ torch.backends.cudnn.benchmark = True
160
+ if torch.cuda.is_available():
161
+ torch.backends.cuda.matmul.allow_tf32 = True
162
+ torch.backends.cudnn.allow_tf32 = True
163
+ # Set a very small attention slice size for RTX 2070 to avoid OOM
164
+ torch.backends.cuda.max_split_size_mb = 128
165
+
166
+ # Move models to device with consistent dtype
167
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
168
+ vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
169
+ unet = unet.to(device=device, dtype=dtype)
170
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
171
+
172
+
173
+ ddim_scheduler = DDIMScheduler(
174
+ num_train_timesteps=1000,
175
+ beta_start=0.00085,
176
+ beta_end=0.012,
177
+ beta_schedule="scaled_linear",
178
+ clip_sample=False,
179
+ set_alpha_to_one=False,
180
+ steps_offset=1,
181
+ )
182
+
183
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
184
+ num_train_timesteps=1000,
185
+ beta_start=0.00085,
186
+ beta_end=0.012,
187
+ steps_offset=1
188
+ )
189
+
190
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
191
+ num_train_timesteps=1000,
192
+ beta_start=0.00085,
193
+ beta_end=0.012,
194
+ algorithm_type="sde-dpmsolver++",
195
+ use_karras_sigmas=True,
196
+ steps_offset=1
197
+ )
198
+
199
+ # Pipelines
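+ # Both pipelines share the patched UNet, VAE and text encoder and use the
+ # DPM++ 2M SDE Karras scheduler; the safety checker and feature extractor are disabled.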
200
+
201
+ t2i_pipe = StableDiffusionPipeline(
202
+ vae=vae,
203
+ text_encoder=text_encoder,
204
+ tokenizer=tokenizer,
205
+ unet=unet,
206
+ scheduler=dpmpp_2m_sde_karras_scheduler,
207
+ safety_checker=None,
208
+ requires_safety_checker=False,
209
+ feature_extractor=None,
210
+ image_encoder=None
211
+ )
212
+
213
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
214
+ vae=vae,
215
+ text_encoder=text_encoder,
216
+ tokenizer=tokenizer,
217
+ unet=unet,
218
+ scheduler=dpmpp_2m_sde_karras_scheduler,
219
+ safety_checker=None,
220
+ requires_safety_checker=False,
221
+ feature_extractor=None,
222
+ image_encoder=None
223
+ )
224
+
225
+
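+ # Long prompts are handled by splitting the token stream into 75-token chunks,
+ # wrapping each chunk with BOS/EOS, padding to the model's 77-token window, and
+ # encoding every chunk; encode_prompt_pair then repeats the shorter of the
+ # positive/negative encodings so both sides have the same number of chunks.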
226
+ @torch.inference_mode()
227
+ def encode_prompt_inner(txt: str):
228
+ max_length = tokenizer.model_max_length
229
+ chunk_length = tokenizer.model_max_length - 2
230
+ id_start = tokenizer.bos_token_id
231
+ id_end = tokenizer.eos_token_id
232
+ id_pad = id_end
233
+
234
+ def pad(x, p, i):
235
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
236
+
237
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
238
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
239
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
240
+
241
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
242
+ conds = text_encoder(token_ids).last_hidden_state
243
+
244
+ return conds
245
+
246
+
247
+ @torch.inference_mode()
248
+ def encode_prompt_pair(positive_prompt, negative_prompt):
249
+ c = encode_prompt_inner(positive_prompt)
250
+ uc = encode_prompt_inner(negative_prompt)
251
+
252
+ c_len = float(len(c))
253
+ uc_len = float(len(uc))
254
+ max_count = max(c_len, uc_len)
255
+ c_repeat = int(math.ceil(max_count / c_len))
256
+ uc_repeat = int(math.ceil(max_count / uc_len))
257
+ max_chunk = max(len(c), len(uc))
258
+
259
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
260
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
261
+
262
+ c = torch.cat([p[None, ...] for p in c], dim=1)
263
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
264
+
265
+ return c, uc
266
+
267
+
268
+ @torch.inference_mode()
269
+ def pytorch2numpy(imgs, quant=True):
270
+ results = []
271
+ for x in imgs:
272
+ y = x.movedim(0, -1)
273
+
274
+ if quant:
275
+ y = y * 127.5 + 127.5
276
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
277
+ else:
278
+ y = y * 0.5 + 0.5
279
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
280
+
281
+ results.append(y)
282
+ return results
283
+
284
+
285
+ @torch.inference_mode()
286
+ def numpy2pytorch(imgs):
287
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
288
+ h = h.movedim(-1, 1)
289
+ return h
290
+
291
+
292
+ def resize_and_center_crop(image, target_width, target_height):
293
+ pil_image = Image.fromarray(image)
294
+ original_width, original_height = pil_image.size
295
+ scale_factor = max(target_width / original_width, target_height / original_height)
296
+ resized_width = int(round(original_width * scale_factor))
297
+ resized_height = int(round(original_height * scale_factor))
298
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
299
+ left = (resized_width - target_width) / 2
300
+ top = (resized_height - target_height) / 2
301
+ right = (resized_width + target_width) / 2
302
+ bottom = (resized_height + target_height) / 2
303
+ cropped_image = resized_image.crop((left, top, right, bottom))
304
+ return np.array(cropped_image)
305
+
306
+
307
+ def resize_without_crop(image, target_width, target_height):
308
+ pil_image = Image.fromarray(image)
309
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
310
+ return np.array(resized_image)
311
+
312
+
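+ # run_rmbg flattens RGBA inputs onto white, predicts an alpha matte with BriaRMBG at a
+ # reduced resolution (dimensions rounded to multiples of 64), upsamples it back to full
+ # size, and returns the foreground composited over neutral grey (127) plus an RGBA cutout.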
313
+ @torch.inference_mode()
314
+ def run_rmbg(img, sigma=0.0):
315
+ # Convert RGBA to RGB if needed
316
+ if img.shape[-1] == 4:
317
+ # Use white background for alpha composition
318
+ alpha = img[..., 3:] / 255.0
319
+ rgb = img[..., :3]
320
+ white_bg = np.ones_like(rgb) * 255
321
+ img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
322
+
323
+ H, W, C = img.shape
324
+ assert C == 3
325
+ k = (256.0 / float(H * W)) ** 0.5
326
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
327
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
328
+ alpha = rmbg(feed)[0][0]
329
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
330
+ alpha = alpha.movedim(1, -1)[0]
331
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
332
+
333
+ # Create RGBA image
334
+ rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
335
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
336
+ return result.clip(0, 255).astype(np.uint8), rgba
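+ # process() runs a low-resolution pass first (text-to-image, or img2img against the
+ # selected background latent), decodes and upscales the result by highres_scale to a
+ # multiple of 64, then refines it with a second img2img pass at highres_denoise strength;
+ # the foreground condition is re-encoded at each resolution and passed via concat_conds.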
337
+ @torch.inference_mode()
338
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
339
+ clear_memory()
340
+
341
+ # Get input dimensions
342
+ input_height, input_width = input_fg.shape[:2]
343
+
344
+ bg_source = BGSource(bg_source)
+ input_bg = None # no background upload in this tab; the gradient branches below may set it
345
+
346
+
347
+ if bg_source == BGSource.UPLOAD:
348
+ pass
349
+ elif bg_source == BGSource.UPLOAD_FLIP:
350
+ input_bg = np.fliplr(input_bg)
351
+ elif bg_source == BGSource.GREY:
352
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
353
+ elif bg_source == BGSource.LEFT:
354
+ gradient = np.linspace(255, 0, input_width)
355
+ image = np.tile(gradient, (input_height, 1))
356
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
357
+ elif bg_source == BGSource.RIGHT:
358
+ gradient = np.linspace(0, 255, input_width)
359
+ image = np.tile(gradient, (input_height, 1))
360
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
361
+ elif bg_source == BGSource.TOP:
362
+ gradient = np.linspace(255, 0, input_height)[:, None]
363
+ image = np.tile(gradient, (1, input_width))
364
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
365
+ elif bg_source == BGSource.BOTTOM:
366
+ gradient = np.linspace(0, 255, input_height)[:, None]
367
+ image = np.tile(gradient, (1, input_width))
368
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
369
+ else:
370
+ raise ValueError('Wrong initial latent!')
371
+
372
+ rng = torch.Generator(device=device).manual_seed(int(seed))
373
+
374
+ # Use input dimensions directly
375
+ fg = resize_without_crop(input_fg, input_width, input_height)
376
+
377
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
378
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
379
+
380
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
381
+
382
+ if input_bg is None:
383
+ latents = t2i_pipe(
384
+ prompt_embeds=conds,
385
+ negative_prompt_embeds=unconds,
386
+ width=input_width,
387
+ height=input_height,
388
+ num_inference_steps=steps,
389
+ num_images_per_prompt=num_samples,
390
+ generator=rng,
391
+ output_type='latent',
392
+ guidance_scale=cfg,
393
+ cross_attention_kwargs={'concat_conds': concat_conds},
394
+ ).images.to(vae.dtype) / vae.config.scaling_factor
395
+ else:
396
+ bg = resize_without_crop(input_bg, input_width, input_height)
397
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
398
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
399
+ latents = i2i_pipe(
400
+ image=bg_latent,
401
+ strength=lowres_denoise,
402
+ prompt_embeds=conds,
403
+ negative_prompt_embeds=unconds,
404
+ width=input_width,
405
+ height=input_height,
406
+ num_inference_steps=int(round(steps / lowres_denoise)),
407
+ num_images_per_prompt=num_samples,
408
+ generator=rng,
409
+ output_type='latent',
410
+ guidance_scale=cfg,
411
+ cross_attention_kwargs={'concat_conds': concat_conds},
412
+ ).images.to(vae.dtype) / vae.config.scaling_factor
413
+
414
+ pixels = vae.decode(latents).sample
415
+ pixels = pytorch2numpy(pixels)
416
+ pixels = [resize_without_crop(
417
+ image=p,
418
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
419
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
420
+ for p in pixels]
421
+
422
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
423
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
424
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
425
+
426
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
427
+
428
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
429
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
430
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
431
+
432
+ latents = i2i_pipe(
433
+ image=latents,
434
+ strength=highres_denoise,
435
+ prompt_embeds=conds,
436
+ negative_prompt_embeds=unconds,
437
+ width=highres_width,
438
+ height=highres_height,
439
+ num_inference_steps=int(round(steps / highres_denoise)),
440
+ num_images_per_prompt=num_samples,
441
+ generator=rng,
442
+ output_type='latent',
443
+ guidance_scale=cfg,
444
+ cross_attention_kwargs={'concat_conds': concat_conds},
445
+ ).images.to(vae.dtype) / vae.config.scaling_factor
446
+
447
+ pixels = vae.decode(latents).sample
448
+ pixels = pytorch2numpy(pixels)
449
+
450
+ # Resize back to input dimensions
451
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
452
+ pixels = np.stack(pixels)
453
+
454
+ return pixels
455
+
456
+ @torch.inference_mode()
457
+ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
458
+ clear_memory()
459
+ bg_source = BGSource(bg_source)
460
+
461
+ if bg_source == BGSource.UPLOAD:
462
+ pass
463
+ elif bg_source == BGSource.UPLOAD_FLIP:
464
+ input_bg = np.fliplr(input_bg)
465
+ elif bg_source == BGSource.GREY:
466
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
467
+ elif bg_source == BGSource.LEFT:
468
+ gradient = np.linspace(224, 32, image_width)
469
+ image = np.tile(gradient, (image_height, 1))
470
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
471
+ elif bg_source == BGSource.RIGHT:
472
+ gradient = np.linspace(32, 224, image_width)
473
+ image = np.tile(gradient, (image_height, 1))
474
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
475
+ elif bg_source == BGSource.TOP:
476
+ gradient = np.linspace(224, 32, image_height)[:, None]
477
+ image = np.tile(gradient, (1, image_width))
478
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
479
+ elif bg_source == BGSource.BOTTOM:
480
+ gradient = np.linspace(32, 224, image_height)[:, None]
481
+ image = np.tile(gradient, (1, image_width))
482
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
483
+ else:
484
+ raise ValueError('Wrong background source!')
485
+
486
+ rng = torch.Generator(device=device).manual_seed(seed)
487
+
488
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
489
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
490
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
491
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
492
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
493
+
494
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
495
+
496
+ latents = t2i_pipe(
497
+ prompt_embeds=conds,
498
+ negative_prompt_embeds=unconds,
499
+ width=image_width,
500
+ height=image_height,
501
+ num_inference_steps=steps,
502
+ num_images_per_prompt=num_samples,
503
+ generator=rng,
504
+ output_type='latent',
505
+ guidance_scale=cfg,
506
+ cross_attention_kwargs={'concat_conds': concat_conds},
507
+ ).images.to(vae.dtype) / vae.config.scaling_factor
508
+
509
+ pixels = vae.decode(latents).sample
510
+ pixels = pytorch2numpy(pixels)
511
+ pixels = [resize_without_crop(
512
+ image=p,
513
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
514
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
515
+ for p in pixels]
516
+
517
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
518
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
519
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
520
+
521
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
522
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
523
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
524
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
525
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
526
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
527
+
528
+ latents = i2i_pipe(
529
+ image=latents,
530
+ strength=highres_denoise,
531
+ prompt_embeds=conds,
532
+ negative_prompt_embeds=unconds,
533
+ width=image_width,
534
+ height=image_height,
535
+ num_inference_steps=int(round(steps / highres_denoise)),
536
+ num_images_per_prompt=num_samples,
537
+ generator=rng,
538
+ output_type='latent',
539
+ guidance_scale=cfg,
540
+ cross_attention_kwargs={'concat_conds': concat_conds},
541
+ ).images.to(vae.dtype) / vae.config.scaling_factor
542
+
543
+ pixels = vae.decode(latents).sample
544
+ pixels = pytorch2numpy(pixels, quant=False)
545
+
546
+ clear_memory()
547
+ return pixels, [fg, bg]
548
+
549
+
550
+ @torch.inference_mode()
551
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
552
+ input_fg, matting = run_rmbg(input_fg)
553
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
554
+ return input_fg, results
555
+
556
+
557
+
558
+ @torch.inference_mode()
559
+ def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
560
+ bg_source = BGSource(bg_source)
561
+
562
+ # Convert numerical inputs to appropriate types
563
+ image_width = int(image_width)
564
+ image_height = int(image_height)
565
+ num_samples = int(num_samples)
566
+ seed = int(seed)
567
+ steps = int(steps)
568
+ cfg = float(cfg)
569
+ highres_scale = float(highres_scale)
570
+ highres_denoise = float(highres_denoise)
571
+
572
+ if bg_source == BGSource.UPLOAD:
573
+ pass
574
+ elif bg_source == BGSource.UPLOAD_FLIP:
575
+ input_bg = np.fliplr(input_bg)
576
+ elif bg_source == BGSource.GREY:
577
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
578
+ elif bg_source == BGSource.LEFT:
579
+ gradient = np.linspace(224, 32, image_width)
580
+ image = np.tile(gradient, (image_height, 1))
581
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
582
+ elif bg_source == BGSource.RIGHT:
583
+ gradient = np.linspace(32, 224, image_width)
584
+ image = np.tile(gradient, (image_height, 1))
585
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
586
+ elif bg_source == BGSource.TOP:
587
+ gradient = np.linspace(224, 32, image_height)[:, None]
588
+ image = np.tile(gradient, (1, image_width))
589
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
590
+ elif bg_source == BGSource.BOTTOM:
591
+ gradient = np.linspace(32, 224, image_height)[:, None]
592
+ image = np.tile(gradient, (1, image_width))
593
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
594
+ else:
595
+ raise ValueError('Wrong background source!')
596
+
597
+ input_fg, matting = run_rmbg(input_fg)
598
+ results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
599
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
600
+ final_results = results + extra_images
601
+
602
+ # Save the generated images
603
+ save_images(results, prefix="relight")
604
+
605
+ return results
606
+
607
+
608
+ quick_prompts = [
609
+ 'sunshine from window',
610
+ 'neon light, city',
611
+ 'sunset over sea',
612
+ 'golden time',
613
+ 'sci-fi RGB glowing, cyberpunk',
614
+ 'natural lighting',
615
+ 'warm atmosphere, at home, bedroom',
616
+ 'magic lit',
617
+ 'evil, gothic, Yharnam',
618
+ 'light and shadow',
619
+ 'shadow from window',
620
+ 'soft studio lighting',
621
+ 'home atmosphere, cozy bedroom illumination',
622
+ 'neon, Wong Kar-wai, warm'
623
+ ]
624
+ quick_prompts = [[x] for x in quick_prompts]
625
+
626
+
627
+ quick_subjects = [
628
+ 'modern sofa, high quality leather',
629
+ 'elegant dining table, polished wood',
630
+ 'luxurious bed, premium mattress',
631
+ 'minimalist office desk, clean design',
632
+ 'vintage wooden cabinet, antique finish',
633
+ ]
634
+ quick_subjects = [[x] for x in quick_subjects]
635
+
636
+
637
+ class BGSource(Enum):
638
+ UPLOAD = "Use Background Image"
639
+ UPLOAD_FLIP = "Use Flipped Background Image"
640
+ LEFT = "Left Light"
641
+ RIGHT = "Right Light"
642
+ TOP = "Top Light"
643
+ BOTTOM = "Bottom Light"
644
+ GREY = "Ambient"
645
+
646
+ # Add save function
647
+ def save_images(images, prefix="relight"):
648
+ # Create output directory if it doesn't exist
649
+ output_dir = Path("outputs")
650
+ output_dir.mkdir(exist_ok=True)
651
+
652
+ # Create timestamp for unique filenames
653
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
654
+
655
+ saved_paths = []
656
+ for i, img in enumerate(images):
657
+ if isinstance(img, np.ndarray):
658
+ # Convert to PIL Image if numpy array
659
+ img = Image.fromarray(img)
660
+
661
+ # Create filename with timestamp
662
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
663
+ filepath = output_dir / filename
664
+
665
+ # Save image
666
+ img.save(filepath)
+ saved_paths.append(str(filepath))
667
+
668
+
669
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
670
+ return saved_paths
671
+
672
+
673
+ class MaskMover:
674
+ def __init__(self):
675
+ self.extracted_fg = None
676
+ self.original_fg = None # Store original foreground
677
+
678
+ def set_extracted_fg(self, fg_image):
679
+ """Store the extracted foreground with alpha channel"""
680
+ if isinstance(fg_image, np.ndarray):
681
+ self.extracted_fg = fg_image.copy()
682
+ self.original_fg = fg_image.copy()
683
+ else:
684
+ self.extracted_fg = np.array(fg_image)
685
+ self.original_fg = np.array(fg_image)
686
+ return self.extracted_fg
687
+
688
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
689
+ """Create composite with foreground at specified position"""
690
+ if self.original_fg is None or background is None:
691
+ return background
692
+
693
+ # Convert inputs to PIL Images
694
+ if isinstance(background, np.ndarray):
695
+ bg = Image.fromarray(background).convert('RGBA')
696
+ else:
697
+ bg = background.convert('RGBA')
698
+
699
+ if isinstance(self.original_fg, np.ndarray):
700
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
701
+ else:
702
+ fg = self.original_fg.convert('RGBA')
703
+
704
+ # Scale the foreground size
705
+ new_width = int(fg.width * scale)
706
+ new_height = int(fg.height * scale)
707
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
708
+
709
+ # Center the scaled foreground at the position
710
+ x = int(x_pos - new_width / 2)
711
+ y = int(y_pos - new_height / 2)
712
+
713
+ # Create composite
714
+ result = bg.copy()
715
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
716
+
717
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
718
+
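+ # get_depth runs the DepthAnythingV2 'vits' model loaded above, min-max normalizes the
+ # raw depth map to 0-255, and renders it with the INFERNO colormap for display.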
719
+ def get_depth(image):
720
+ if image is None:
721
+ return None
722
+ # Convert from PIL/gradio format to cv2
723
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
724
+ # Get depth map
725
+ depth = model.infer_image(raw_img) # HxW raw depth map
726
+ # Normalize depth for visualization
727
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
728
+ # Convert to RGB for display
729
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
730
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
731
+ return Image.fromarray(depth_colored)
732
+
733
+
734
+ from PIL import Image
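+ # compress_image caps the longest side at 1024 px and re-encodes the image as JPEG,
+ # lowering quality in steps of 5 until the file is under 100 KB (never below quality 20);
+ # it writes to a fixed "compressed_image.jpg" in the working directory and does not
+ # appear to be wired into the UI below.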
735
+
736
+ def compress_image(image):
737
+ # Convert Gradio image (numpy array) to PIL Image
738
+ img = Image.fromarray(image)
739
+
740
+ # Resize image if dimensions are too large
741
+ max_size = 1024 # Maximum dimension size
742
+ if img.width > max_size or img.height > max_size:
743
+ ratio = min(max_size/img.width, max_size/img.height)
744
+ new_size = (int(img.width * ratio), int(img.height * ratio))
745
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
746
+
747
+ quality = 95 # Start with high quality
748
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
749
+
750
+ # Check file size and adjust quality if necessary
751
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
752
+ quality -= 5 # Decrease quality
753
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
754
+ if quality < 20: # Prevent quality from going too low
755
+ break
756
+
757
+ # Convert back to numpy array for Gradio
758
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
759
+ return compressed_img
760
+
761
+
762
+ block = gr.Blocks().queue()
763
+ with block:
764
+ with gr.Tab("Text"):
765
+ with gr.Row():
766
+ gr.Markdown("## Product Placement from Text")
767
+ with gr.Row():
768
+ with gr.Column():
769
+ with gr.Row():
770
+ input_fg = gr.Image(type="numpy", label="Image", height=480)
771
+ output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
772
+ with gr.Group():
773
+ prompt = gr.Textbox(label="Prompt")
774
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
775
+ value=BGSource.GREY.value,
776
+ label="Lighting Preference (Initial Latent)", type='value')
777
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
778
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
779
+ relight_button = gr.Button(value="Relight")
780
+
781
+ with gr.Group():
782
+ with gr.Row():
783
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
784
+ seed = gr.Number(label="Seed", value=12345, precision=0)
785
+
786
+ with gr.Row():
787
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
788
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
789
+
790
+ with gr.Accordion("Advanced options", open=False):
791
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
792
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
793
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
794
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
795
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
796
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
797
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
798
+ with gr.Column():
799
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
800
+ with gr.Row():
801
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
802
+ # gr.Examples(
803
+ # fn=lambda *args: ([args[-1]], None),
804
+ # examples=db_examples.foreground_conditioned_examples,
805
+ # inputs=[
806
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
807
+ # ],
808
+ # outputs=[result_gallery, output_bg],
809
+ # run_on_click=True, examples_per_page=1024
810
+ # )
811
+ ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
812
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
813
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
814
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
815
+
816
+ with gr.Tab("Background", visible=False):
817
+ mask_mover = MaskMover()
818
+
819
+
820
+ with gr.Row():
821
+ gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
822
+ gr.Markdown("πŸ’Ύ Generated images are automatically saved to 'outputs' folder")
823
+
824
+ with gr.Row():
825
+ with gr.Column():
826
+ # Step 1: Input and Extract
827
+ with gr.Row():
828
+ with gr.Group():
829
+ gr.Markdown("### Step 1: Extract Foreground")
830
+ input_image = gr.Image(type="numpy", label="Input Image", height=480)
831
+ # find_objects_button = gr.Button(value="Find Objects")
832
+ extract_button = gr.Button(value="Remove Background")
833
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
834
+
835
+ with gr.Row():
836
+ # Step 2: Background and Position
837
+ with gr.Group():
838
+ gr.Markdown("### Step 2: Position on Background")
839
+ input_bg = gr.Image(type="numpy", label="Background Image", height=480)
840
+
841
+ with gr.Row():
842
+ x_slider = gr.Slider(
843
+ minimum=0,
844
+ maximum=1000,
845
+ label="X Position",
846
+ value=500,
847
+ visible=False
848
+ )
849
+ y_slider = gr.Slider(
850
+ minimum=0,
851
+ maximum=1000,
852
+ label="Y Position",
853
+ value=500,
854
+ visible=False
855
+ )
856
+ fg_scale_slider = gr.Slider(
857
+ label="Foreground Scale",
858
+ minimum=0.01,
859
+ maximum=3.0,
860
+ value=1.0,
861
+ step=0.01
862
+ )
863
+
864
+ editor = gr.ImageEditor(
865
+ type="numpy",
866
+ label="Position Foreground",
867
+ height=480,
868
+ visible=False
869
+ )
870
+ get_depth_button = gr.Button(value="Get Depth")
871
+ depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
872
+
873
+ # Step 3: Relighting Options
874
+ with gr.Group():
875
+ gr.Markdown("### Step 3: Relighting Settings")
876
+ prompt = gr.Textbox(label="Prompt")
877
+ bg_source = gr.Radio(
878
+ choices=[e.value for e in BGSource],
879
+ value=BGSource.UPLOAD.value,
880
+ label="Background Source",
881
+ type='value'
882
+ )
883
+
884
+ example_prompts = gr.Dataset(
885
+ samples=quick_prompts,
886
+ label='Prompt Quick List',
887
+ components=[prompt]
888
+ )
889
+ # bg_gallery = gr.Gallery(
890
+ # height=450,
891
+ # label='Background Quick List',
892
+ # value=db_examples.bg_samples,
893
+ # columns=5,
894
+ # allow_preview=False
895
+ # )
896
+ relight_button_bg = gr.Button(value="Relight")
897
+
898
+ # Additional settings
899
+ with gr.Group():
900
+ with gr.Row():
901
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
902
+ seed = gr.Number(label="Seed", value=12345, precision=0)
903
+ with gr.Row():
904
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
905
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
906
+
907
+ with gr.Accordion("Advanced options", open=False):
908
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
909
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
910
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
911
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
912
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
913
+ n_prompt = gr.Textbox(
914
+ label="Negative Prompt",
915
+ value='lowres, bad anatomy, bad hands, cropped, worst quality'
916
+ )
917
+
918
+ with gr.Column():
919
+ result_gallery = gr.Image(height=832, label='Outputs')
920
+
921
+ def extract_foreground(image):
922
+ if image is None:
923
+ return None, gr.update(visible=True), gr.update(visible=True)
924
+ result, rgba = run_rmbg(image)
925
+ mask_mover.set_extracted_fg(rgba)
926
+
927
+ return result, gr.update(visible=True), gr.update(visible=True)
928
+
929
+
930
+ original_bg = None
931
+
932
+ extract_button.click(
933
+ fn=extract_foreground,
934
+ inputs=[input_image],
935
+ outputs=[extracted_fg, x_slider, y_slider]
936
+ )
937
+
938
+ # find_objects_button.click(
939
+ # fn=find_objects,
940
+ # inputs=[input_image],
941
+ # outputs=[extracted_fg]
942
+ # )
943
+
944
+ get_depth_button.click(
945
+ fn=get_depth,
946
+ inputs=[input_bg],
947
+ outputs=[depth_image]
948
+ )
949
+
950
+ # def update_position(background, x_pos, y_pos, scale):
951
+ # """Update composite when position changes"""
952
+ # global original_bg
953
+ # if background is None:
954
+ # return None
955
+
956
+ # if original_bg is None:
957
+ # original_bg = background.copy()
958
+
959
+ # # Convert string values to float
960
+ # x_pos = float(x_pos)
961
+ # y_pos = float(y_pos)
962
+ # scale = float(scale)
963
+
964
+ # return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
965
+
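+ # BackgroundManager caches the first background it receives and re-composites the
+ # extracted foreground onto that cached copy whenever a position/scale slider changes,
+ # so repeated slider edits do not compound on an already-composited image.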
966
+ class BackgroundManager:
967
+ def __init__(self):
968
+ self.original_bg = None
969
+
970
+ def update_position(self, background, x_pos, y_pos, scale):
971
+ """Update composite when position changes"""
972
+ if background is None:
973
+ return None
974
+
975
+ if self.original_bg is None:
976
+ self.original_bg = background.copy()
977
+
978
+ # Convert string values to float
979
+ x_pos = float(x_pos)
980
+ y_pos = float(y_pos)
981
+ scale = float(scale)
982
+
983
+ return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
984
+
985
+ # Create an instance of BackgroundManager
986
+ bg_manager = BackgroundManager()
987
+
988
+
989
+ x_slider.change(
990
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
991
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
992
+ outputs=[input_bg]
993
+ )
994
+
995
+ y_slider.change(
996
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
997
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
998
+ outputs=[input_bg]
999
+ )
1000
+
1001
+ fg_scale_slider.change(
1002
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1003
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1004
+ outputs=[input_bg]
1005
+ )
1006
+
1007
+ # Update inputs list to include fg_scale_slider
1008
+
1009
+ def process_relight_with_position(*args):
1010
+ if mask_mover.extracted_fg is None:
1011
+ gr.Warning("Please extract foreground first")
1012
+ return None
1013
+
1014
+ background = args[1] # Get background image
1015
+ x_pos = float(args[-3]) # x_slider value
1016
+ y_pos = float(args[-2]) # y_slider value
1017
+ scale = float(args[-1]) # fg_scale_slider value
1018
+
1019
+ # Get original foreground size after scaling
1020
+ fg = Image.fromarray(mask_mover.original_fg)
1021
+ new_width = int(fg.width * scale)
1022
+ new_height = int(fg.height * scale)
1023
+
1024
+ # Calculate crop region around foreground position
1025
+ crop_x = int(x_pos - new_width/2)
1026
+ crop_y = int(y_pos - new_height/2)
1027
+ crop_width = new_width
1028
+ crop_height = new_height
1029
+
1030
+ # Add padding for context (20% extra on each side)
1031
+ padding = 0.2
1032
+ crop_x = int(crop_x - crop_width * padding)
1033
+ crop_y = int(crop_y - crop_height * padding)
1034
+ crop_width = int(crop_width * (1 + 2 * padding))
1035
+ crop_height = int(crop_height * (1 + 2 * padding))
1036
+
1037
+ # Ensure crop dimensions are multiples of 8
1038
+ crop_width = ((crop_width + 7) // 8) * 8
1039
+ crop_height = ((crop_height + 7) // 8) * 8
1040
+
1041
+ # Ensure crop region is within image bounds
1042
+ bg_height, bg_width = background.shape[:2]
1043
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
1044
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
1045
+
1046
+ # Get actual crop dimensions after boundary check
1047
+ crop_width = min(crop_width, bg_width - crop_x)
1048
+ crop_height = min(crop_height, bg_height - crop_y)
1049
+
1050
+ # Ensure dimensions are multiples of 8 again
1051
+ crop_width = (crop_width // 8) * 8
1052
+ crop_height = (crop_height // 8) * 8
1053
+
1054
+ # Crop region from background
1055
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
1056
+
1057
+ # Create composite in cropped region
1058
+ fg_local_x = int(new_width/2 + crop_width*padding)
1059
+ fg_local_y = int(new_height/2 + crop_height*padding)
1060
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
1061
+
1062
+ # Process the cropped region
1063
+ crop_args = list(args)
1064
+ crop_args[0] = cropped_composite
1065
+ crop_args[1] = crop_region
1066
+ crop_args[3] = crop_width
1067
+ crop_args[4] = crop_height
1068
+ crop_args = crop_args[:-3] # Remove position and scale arguments
1069
+
1070
+ # Get relit result
1071
+ relit_crop = process_relight_bg(*crop_args)[0]
1072
+
1073
+ # Resize relit result to match crop dimensions if needed
1074
+ if relit_crop.shape[:2] != (crop_height, crop_width):
1075
+ relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
1076
+
1077
+ # Place relit crop back into original background
1078
+ result = background.copy()
1079
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
1080
+
1081
+ return result
1082
+
1083
+ ips_bg = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source, x_slider, y_slider, fg_scale_slider]
1084
+
1085
+ # Update button click events with new inputs list
1086
+ relight_button_bg.click(
1087
+ fn=process_relight_with_position,
1088
+ inputs=ips_bg,
1089
+ outputs=[result_gallery]
1090
+ )
1091
+
1092
+
1093
+ example_prompts.click(
1094
+ fn=lambda x: x[0],
1095
+ inputs=example_prompts,
1096
+ outputs=prompt,
1097
+ show_progress=False,
1098
+ queue=False
1099
+ )
1100
+
1101
+
1102
+
1103
+ block.launch(server_name='0.0.0.0', share=True)