zhiweili commited on
Commit
6baa93c
·
1 Parent(s): f833ac0

add app_face

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .vscode
2
+ .DS_Store
3
+ __pycache__
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
10
  license: mit
11
  ---
12
 
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
10
  license: mit
11
  ---
12
 
13
+ Modified from: https://huggingface.co/spaces/turboedit/turbo_edit
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from app_base import create_demo as create_demo_face
4
+
5
+ with gr.Blocks(css="style.css") as demo:
6
+ with gr.Tabs():
7
+ with gr.Tab(label="Face"):
8
+ create_demo_face()
9
+
10
+ demo.launch()
app_base.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import time
4
+ import torch
5
+
6
+ from PIL import Image
7
+ from segment_utils import(
8
+ segment_image,
9
+ restore_result,
10
+ )
11
+ from enhance_utils import enhance_image
12
+
13
+ DEFAULT_SRC_PROMPT = "a woman, photo"
14
+ DEFAULT_EDIT_PROMPT = "a beautiful woman, photo, hollywood style face, 8k, high quality"
15
+
16
+ DEFAULT_CATEGORY = "face"
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ def create_demo() -> gr.Blocks:
21
+ from inversion_run_base import run as base_run
22
+
23
+ @spaces.GPU(duration=10)
24
+ def image_to_image(
25
+ input_image: Image,
26
+ input_image_prompt: str,
27
+ edit_prompt: str,
28
+ seed: int,
29
+ w1: float,
30
+ num_steps: int,
31
+ start_step: int,
32
+ guidance_scale: float,
33
+ generate_size: int,
34
+ adapter_weights: float,
35
+ enhance_face: bool = True,
36
+ ):
37
+ w2 = 1.0
38
+ run_task_time = 0
39
+ time_cost_str = ''
40
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
41
+ run_model = base_run
42
+ res_image = run_model(
43
+ input_image,
44
+ input_image_prompt,
45
+ edit_prompt,
46
+ generate_size,
47
+ seed,
48
+ w1,
49
+ w2,
50
+ num_steps,
51
+ start_step,
52
+ guidance_scale,
53
+ adapter_weights,
54
+ )
55
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
56
+ enhanced_image = enhance_image(res_image, enhance_face)
57
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
58
+
59
+ return enhanced_image, res_image, time_cost_str
60
+
61
+ def get_time_cost(run_task_time, time_cost_str):
62
+ now_time = int(time.time()*1000)
63
+ if run_task_time == 0:
64
+ time_cost_str = 'start'
65
+ else:
66
+ if time_cost_str != '':
67
+ time_cost_str += f'-->'
68
+ time_cost_str += f'{now_time - run_task_time}'
69
+ run_task_time = now_time
70
+ return run_task_time, time_cost_str
71
+
72
+ with gr.Blocks() as demo:
73
+ croper = gr.State()
74
+ with gr.Row():
75
+ with gr.Column():
76
+ input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
77
+ edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
78
+ category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
79
+ with gr.Column():
80
+ num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
81
+ start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
82
+ with gr.Accordion("Advanced Options", open=False):
83
+ guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
84
+ generate_size = gr.Number(label="Generate Size", value=768)
85
+ mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
86
+ mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
87
+ enhance_face = gr.Checkbox(label="Enhance Face", value=False)
88
+ adapter_weights = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Adapter Weights", visible=False)
89
+ with gr.Column():
90
+ seed = gr.Number(label="Seed", value=8)
91
+ w1 = gr.Number(label="W1", value=2.5)
92
+ g_btn = gr.Button("Edit Image")
93
+
94
+ with gr.Row():
95
+ with gr.Column():
96
+ input_image = gr.Image(label="Input Image", type="pil")
97
+ with gr.Column():
98
+ restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
99
+ download_path = gr.File(label="Download the output image", interactive=False)
100
+ with gr.Column():
101
+ origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
102
+ enhanced_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
103
+ generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
104
+ generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
105
+
106
+ g_btn.click(
107
+ fn=segment_image,
108
+ inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
109
+ outputs=[origin_area_image, croper],
110
+ ).success(
111
+ fn=image_to_image,
112
+ inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size, adapter_weights, enhance_face],
113
+ outputs=[enhanced_image, generated_image, generated_cost],
114
+ ).success(
115
+ fn=restore_result,
116
+ inputs=[croper, category, enhanced_image],
117
+ outputs=[restored_image, download_path],
118
+ )
119
+
120
+ return demo
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837
config.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ml_collections import config_dict
2
+ import yaml
3
+ from diffusers.schedulers import (
4
+ DDIMScheduler,
5
+ EulerAncestralDiscreteScheduler,
6
+ EulerDiscreteScheduler,
7
+ DDPMScheduler,
8
+ )
9
+ from inversion_utils import (
10
+ deterministic_ddim_step,
11
+ deterministic_ddpm_step,
12
+ deterministic_euler_step,
13
+ deterministic_non_ancestral_euler_step,
14
+ )
15
+
16
+ BREAKDOWNS = ["x_t_c_hat", "x_t_hat_c", "no_breakdown", "x_t_hat_c_with_zeros"]
17
+ SCHEDULERS = ["ddpm", "ddim", "euler", "euler_non_ancestral"]
18
+ MODELS = [
19
+ "stabilityai/sdxl-turbo",
20
+ "stabilityai/stable-diffusion-xl-base-1.0",
21
+ "CompVis/stable-diffusion-v1-4",
22
+ ]
23
+
24
+ def get_num_steps_actual(cfg):
25
+ return (
26
+ cfg.num_steps_inversion
27
+ - cfg.step_start
28
+ + (1 if cfg.clean_step_timestep > 0 else 0)
29
+ if cfg.timesteps is None
30
+ else len(cfg.timesteps) + (1 if cfg.clean_step_timestep > 0 else 0)
31
+ )
32
+
33
+
34
+ def get_config(args):
35
+ if args.config_from_file and args.config_from_file != "":
36
+ with open(args.config_from_file, "r") as f:
37
+ cfg = config_dict.ConfigDict(yaml.safe_load(f))
38
+
39
+ num_steps_actual = get_num_steps_actual(cfg)
40
+
41
+ else:
42
+ cfg = config_dict.ConfigDict()
43
+
44
+ cfg.seed = 2
45
+ cfg.self_r = 0.5
46
+ cfg.cross_r = 0.9
47
+ cfg.eta = 1
48
+ cfg.scheduler_type = SCHEDULERS[0]
49
+
50
+ cfg.num_steps_inversion = 50 # timesteps: 999, 799, 599, 399, 199
51
+ cfg.step_start = 20
52
+ cfg.timesteps = None
53
+ cfg.noise_timesteps = None
54
+ num_steps_actual = get_num_steps_actual(cfg)
55
+ cfg.ws1 = [2] * num_steps_actual
56
+ cfg.ws2 = [1] * num_steps_actual
57
+ cfg.real_cfg_scale = 0
58
+ cfg.real_cfg_scale_save = 0
59
+ cfg.breakdown = BREAKDOWNS[1]
60
+ cfg.noise_shift_delta = 1
61
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
62
+
63
+ cfg.clean_step_timestep = 0
64
+
65
+ cfg.model = MODELS[1]
66
+
67
+ if cfg.scheduler_type == "ddim":
68
+ cfg.scheduler_class = DDIMScheduler
69
+ cfg.step_function = deterministic_ddim_step
70
+ elif cfg.scheduler_type == "ddpm":
71
+ cfg.scheduler_class = DDPMScheduler
72
+ cfg.step_function = deterministic_ddpm_step
73
+ elif cfg.scheduler_type == "euler":
74
+ cfg.scheduler_class = EulerAncestralDiscreteScheduler
75
+ cfg.step_function = deterministic_euler_step
76
+ elif cfg.scheduler_type == "euler_non_ancestral":
77
+ cfg.scheduler_class = EulerDiscreteScheduler
78
+ cfg.step_function = deterministic_non_ancestral_euler_step
79
+ else:
80
+ raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")
81
+
82
+ with cfg.ignore_type():
83
+ if isinstance(cfg.max_norm_zs, (int, float)):
84
+ cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
85
+
86
+ if isinstance(cfg.ws1, (int, float)):
87
+ cfg.ws1 = [cfg.ws1] * num_steps_actual
88
+
89
+ if isinstance(cfg.ws2, (int, float)):
90
+ cfg.ws2 = [cfg.ws2] * num_steps_actual
91
+
92
+ if not hasattr(cfg, "update_eta"):
93
+ cfg.update_eta = False
94
+
95
+ if not hasattr(cfg, "save_timesteps"):
96
+ cfg.save_timesteps = None
97
+
98
+ if not hasattr(cfg, "scheduler_timesteps"):
99
+ cfg.scheduler_timesteps = None
100
+
101
+ assert (
102
+ cfg.scheduler_type == "ddpm" or cfg.timesteps is None
103
+ ), "timesteps must be None for ddim/euler"
104
+
105
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
106
+ assert (
107
+ len(cfg.max_norm_zs) == num_steps_actual
108
+ ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
109
+
110
+ assert (
111
+ len(cfg.ws1) == num_steps_actual
112
+ ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})"
113
+
114
+ assert (
115
+ len(cfg.ws2) == num_steps_actual
116
+ ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})"
117
+
118
+ assert cfg.noise_timesteps is None or len(cfg.noise_timesteps) == (
119
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
120
+ ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual ({num_steps_actual})"
121
+
122
+ assert cfg.save_timesteps is None or len(cfg.save_timesteps) == (
123
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
124
+ ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual ({num_steps_actual})"
125
+
126
+ return cfg
127
+
128
+
129
+ def get_config_name(config, args):
130
+ if args.folder_name is not None and args.folder_name != "":
131
+ return args.folder_name
132
+ timesteps_str = (
133
+ f"step_start {config.step_start}"
134
+ if config.timesteps is None
135
+ else f"timesteps {config.timesteps}"
136
+ )
137
+ return f"""\
138
+ ws1 {config.ws1[0]} ws2 {config.ws2[0]} real_cfg_scale {config.real_cfg_scale} {timesteps_str} \
139
+ real_cfg_scale_save {config.real_cfg_scale_save} seed {config.seed} max_norm_zs {config.max_norm_zs[-1]} noise_shift_delta {config.noise_shift_delta} \
140
+ scheduler_type {config.scheduler_type} fp16 {args.fp16}\
141
+ """
croper.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+
4
+ from PIL import Image
5
+
6
+ class Croper:
7
+ def __init__(
8
+ self,
9
+ input_image: PIL.Image,
10
+ target_mask: np.ndarray,
11
+ mask_size: int = 256,
12
+ mask_expansion: int = 20,
13
+ ):
14
+ self.input_image = input_image
15
+ self.target_mask = target_mask
16
+ self.mask_size = mask_size
17
+ self.mask_expansion = mask_expansion
18
+
19
+ def corp_mask_image(self):
20
+ target_mask = self.target_mask
21
+ input_image = self.input_image
22
+ mask_expansion = self.mask_expansion
23
+ original_width, original_height = input_image.size
24
+ mask_indices = np.where(target_mask)
25
+ start_y = np.min(mask_indices[0])
26
+ end_y = np.max(mask_indices[0])
27
+ start_x = np.min(mask_indices[1])
28
+ end_x = np.max(mask_indices[1])
29
+ mask_height = end_y - start_y
30
+ mask_width = end_x - start_x
31
+ # choose the max side length
32
+ max_side_length = max(mask_height, mask_width)
33
+ # expand the mask area
34
+ height_diff = (max_side_length - mask_height) // 2
35
+ width_diff = (max_side_length - mask_width) // 2
36
+ start_y = start_y - mask_expansion - height_diff
37
+ if start_y < 0:
38
+ start_y = 0
39
+ end_y = end_y + mask_expansion + height_diff
40
+ if end_y > original_height:
41
+ end_y = original_height
42
+ start_x = start_x - mask_expansion - width_diff
43
+ if start_x < 0:
44
+ start_x = 0
45
+ end_x = end_x + mask_expansion + width_diff
46
+ if end_x > original_width:
47
+ end_x = original_width
48
+ expanded_height = end_y - start_y
49
+ expanded_width = end_x - start_x
50
+ expanded_max_side_length = max(expanded_height, expanded_width)
51
+ # calculate the crop area
52
+ crop_mask = target_mask[start_y:end_y, start_x:end_x]
53
+ crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
54
+ crop_mask_end_y = crop_mask_start_y + expanded_height
55
+ crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
56
+ crop_mask_end_x = crop_mask_start_x + expanded_width
57
+ # create a square mask
58
+ square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
59
+ square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
60
+ square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
61
+
62
+ crop_image = input_image.crop((start_x, start_y, end_x, end_y))
63
+ square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
64
+ square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
65
+
66
+ self.origin_start_x = start_x
67
+ self.origin_start_y = start_y
68
+ self.origin_end_x = end_x
69
+ self.origin_end_y = end_y
70
+
71
+ self.square_start_x = crop_mask_start_x
72
+ self.square_start_y = crop_mask_start_y
73
+ self.square_end_x = crop_mask_end_x
74
+ self.square_end_y = crop_mask_end_y
75
+
76
+ self.square_length = expanded_max_side_length
77
+ self.square_mask_image = square_mask_image
78
+ self.square_image = square_image
79
+ self.corp_mask = crop_mask
80
+
81
+ mask_size = self.mask_size
82
+ self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
83
+ self.resized_square_image = square_image.resize((mask_size, mask_size))
84
+
85
+ return self.resized_square_mask_image
86
+
87
+ def restore_result(self, generated_image):
88
+ square_length = self.square_length
89
+ generated_image = generated_image.resize((square_length, square_length))
90
+ square_mask_image = self.square_mask_image
91
+ cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
92
+ cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
93
+
94
+ restored_image = self.input_image.copy()
95
+ restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
96
+
97
+ return restored_image
98
+
99
+ def restore_result_v2(self, generated_image):
100
+ square_length = self.square_length
101
+ generated_image = generated_image.resize((square_length, square_length))
102
+ cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
103
+
104
+ restored_image = self.input_image.copy()
105
+ restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
106
+
107
+ return restored_image
108
+
enhance_utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ from gfpgan.utils import GFPGANer
8
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
9
+ from realesrgan.utils import RealESRGANer
10
+
11
+ os.system("pip freeze")
12
+ if not os.path.exists('GFPGANv1.4.pth'):
13
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
14
+ if not os.path.exists('realesr-general-x4v3.pth'):
15
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
16
+
17
+ os.makedirs('output', exist_ok=True)
18
+
19
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
20
+ model_path = 'realesr-general-x4v3.pth'
21
+ half = True if torch.cuda.is_available() else False
22
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
23
+
24
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
25
+
26
+ def enhance_image(
27
+ pil_image: Image,
28
+ enhance_face: bool = True,
29
+ ):
30
+ img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
31
+
32
+ h, w = img.shape[0:2]
33
+ if h < 300:
34
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
35
+ if enhance_face:
36
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
37
+ else:
38
+ output, _ = upsampler.enhance(img, outscale=2)
39
+ pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
40
+
41
+ return pil_output
inversion_run_base.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+
4
+ from diffusers import (
5
+ DDPMScheduler,
6
+ StableDiffusionXLImg2ImgPipeline,
7
+ )
8
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
9
+ from PIL import Image
10
+ from inversion_utils import get_ddpm_inversion_scheduler, create_xts
11
+ from config import get_config, get_num_steps_actual
12
+ from functools import partial
13
+ from compel import Compel, ReturnedEmbeddingsType
14
+
15
+ os.system("pip freeze")
16
+ if not os.path.exists('GFPGANv1.4.pth'):
17
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
18
+ if not os.path.exists('realesr-general-x4v3.pth'):
19
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
20
+
21
+ os.makedirs('output', exist_ok=True)
22
+
23
+ class Object(object):
24
+ pass
25
+
26
+ args = Object()
27
+ args.images_paths = None
28
+ args.images_folder = None
29
+ args.force_use_cpu = False
30
+ args.folder_name = 'test_measure_time'
31
+ args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
32
+ args.save_intermediate_results = False
33
+ args.batch_size = None
34
+ args.skip_p_to_p = True
35
+ args.only_p_to_p = False
36
+ args.fp16 = False
37
+ args.prompts_file = 'dataset_measure_time/dataset.json'
38
+ args.images_in_prompts_file = None
39
+ args.seed = 986
40
+ args.time_measure_n = 1
41
+
42
+
43
+ assert (
44
+ args.batch_size is None or args.save_intermediate_results is False
45
+ ), "save_intermediate_results is not implemented for batch_size > 1"
46
+
47
+ generator = None
48
+ device = "cuda" if torch.cuda.is_available() else "cpu"
49
+
50
+ # BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
51
+ BASE_MODEL = "stabilityai/sdxl-turbo"
52
+
53
+
54
+ pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
55
+ BASE_MODEL,
56
+ torch_dtype=torch.float16,
57
+ variant="fp16",
58
+ use_safetensors=True,
59
+ )
60
+ pipeline = pipeline.to(device)
61
+
62
+ pipeline.scheduler = DDPMScheduler.from_pretrained(
63
+ BASE_MODEL,
64
+ subfolder="scheduler",
65
+ )
66
+
67
+ config = get_config(args)
68
+
69
+ compel_proc = Compel(
70
+ tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] ,
71
+ text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
72
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
73
+ requires_pooled=[False, True]
74
+ )
75
+
76
+ def run(
77
+ input_image:Image,
78
+ src_prompt:str,
79
+ tgt_prompt:str,
80
+ generate_size:int,
81
+ seed:int,
82
+ w1:float,
83
+ w2:float,
84
+ num_steps:int,
85
+ start_step:int,
86
+ guidance_scale:float,
87
+ adapter_weights:float,
88
+ ):
89
+ generator = torch.Generator().manual_seed(seed)
90
+
91
+ config.num_steps_inversion = num_steps
92
+ config.step_start = start_step
93
+ num_steps_actual = get_num_steps_actual(config)
94
+
95
+
96
+ num_steps_inversion = config.num_steps_inversion
97
+ denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
98
+ print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")
99
+
100
+ timesteps, num_inference_steps = retrieve_timesteps(
101
+ pipeline.scheduler, num_steps_inversion, device, None
102
+ )
103
+ timesteps, num_inference_steps = pipeline.get_timesteps(
104
+ num_inference_steps=num_inference_steps,
105
+ denoising_start=denoising_start,
106
+ strength=0,
107
+ device=device,
108
+ )
109
+ timesteps = timesteps.type(torch.int64)
110
+
111
+ timesteps = [torch.tensor(t) for t in timesteps.tolist()]
112
+ timesteps_len = len(timesteps)
113
+ config.step_start = start_step + num_steps_actual - timesteps_len
114
+ num_steps_actual = timesteps_len
115
+ config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
116
+ print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
117
+ print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")
118
+ pipeline.__call__ = partial(
119
+ pipeline.__call__,
120
+ num_inference_steps=num_steps_inversion,
121
+ guidance_scale=guidance_scale,
122
+ generator=generator,
123
+ denoising_start=denoising_start,
124
+ strength=0,
125
+ )
126
+
127
+ x_0_image = input_image
128
+ x_0 = encode_image(x_0_image, pipeline)
129
+ x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
130
+ x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
131
+ latents = [x_ts[0]]
132
+ x_ts_c_hat = [None]
133
+ config.ws1 = [w1] * num_steps_actual
134
+ config.ws2 = [w2] * num_steps_actual
135
+ pipeline.scheduler = get_ddpm_inversion_scheduler(
136
+ pipeline.scheduler,
137
+ config.step_function,
138
+ config,
139
+ timesteps,
140
+ config.save_timesteps,
141
+ latents,
142
+ x_ts,
143
+ x_ts_c_hat,
144
+ args.save_intermediate_results,
145
+ pipeline,
146
+ x_0,
147
+ v1s_images := [],
148
+ v2s_images := [],
149
+ deltas_images := [],
150
+ v1_x0s := [],
151
+ v2_x0s := [],
152
+ deltas_x0s := [],
153
+ "res12",
154
+ image_name="im_name",
155
+ time_measure_n=args.time_measure_n,
156
+ )
157
+ latent = latents[0].expand(3, -1, -1, -1)
158
+ prompt = [src_prompt, src_prompt, tgt_prompt]
159
+ conditioning, pooled = compel_proc(prompt)
160
+ image = pipeline.__call__(
161
+ image=latent,
162
+ prompt_embeds=conditioning,
163
+ pooled_prompt_embeds=pooled,
164
+ eta=1,
165
+ ).images
166
+ return image[2]
167
+
168
+ def encode_image(image, pipe):
169
+ image = pipe.image_processor.preprocess(image)
170
+ originDtype = pipe.dtype
171
+ image = image.to(device=device, dtype=originDtype)
172
+
173
+ if pipe.vae.config.force_upcast:
174
+ image = image.float()
175
+ pipe.vae.to(dtype=torch.float32)
176
+
177
+ if isinstance(generator, list):
178
+ init_latents = [
179
+ retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
180
+ for i in range(1)
181
+ ]
182
+ init_latents = torch.cat(init_latents, dim=0)
183
+ else:
184
+ init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
185
+
186
+ if pipe.vae.config.force_upcast:
187
+ pipe.vae.to(originDtype)
188
+
189
+ init_latents = init_latents.to(originDtype)
190
+ init_latents = pipe.vae.config.scaling_factor * init_latents
191
+
192
+ return init_latents.to(dtype=torch.float16)
193
+
194
+ def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
195
+ # get the original timestep using init_timestep
196
+ if denoising_start is None:
197
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
198
+ t_start = max(num_inference_steps - init_timestep, 0)
199
+ else:
200
+ t_start = 0
201
+
202
+ timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]
203
+
204
+ # Strength is irrelevant if we directly request a timestep to start at;
205
+ # that is, strength is determined by the denoising_start instead.
206
+ if denoising_start is not None:
207
+ discrete_timestep_cutoff = int(
208
+ round(
209
+ pipe.scheduler.config.num_train_timesteps
210
+ - (denoising_start * pipe.scheduler.config.num_train_timesteps)
211
+ )
212
+ )
213
+
214
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
215
+ if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
216
+ # if the scheduler is a 2nd order scheduler we might have to do +1
217
+ # because `num_inference_steps` might be even given that every timestep
218
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
219
+ # mean that we cut the timesteps in the middle of the denoising step
220
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
221
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
222
+ num_inference_steps = num_inference_steps + 1
223
+
224
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
225
+ timesteps = timesteps[-num_inference_steps:]
226
+ return timesteps, num_inference_steps
227
+
228
+ return timesteps, num_inference_steps - t_start
inversion_utils.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import PIL
4
+
5
+ from typing import List, Optional, Union
6
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
7
+ from PIL import Image
8
+ from diffusers.utils import logging
9
+
10
+ VECTOR_DATA_FOLDER = "vector_data"
11
+ VECTOR_DATA_DICT = "vector_data"
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ def get_ddpm_inversion_scheduler(
16
+ scheduler,
17
+ step_function,
18
+ config,
19
+ timesteps,
20
+ save_timesteps,
21
+ latents,
22
+ x_ts,
23
+ x_ts_c_hat,
24
+ save_intermediate_results,
25
+ pipe,
26
+ x_0,
27
+ v1s_images,
28
+ v2s_images,
29
+ deltas_images,
30
+ v1_x0s,
31
+ v2_x0s,
32
+ deltas_x0s,
33
+ folder_name,
34
+ image_name,
35
+ time_measure_n,
36
+ ):
37
+ def step(
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ eta: float = 0.0,
42
+ use_clipped_model_output: bool = False,
43
+ generator=None,
44
+ variance_noise: Optional[torch.FloatTensor] = None,
45
+ return_dict: bool = True,
46
+ ):
47
+ # if scheduler.is_save:
48
+ # start = timer()
49
+ res_inv = step_save_latents(
50
+ scheduler,
51
+ model_output[:1, :, :, :],
52
+ timestep,
53
+ sample[:1, :, :, :],
54
+ eta,
55
+ use_clipped_model_output,
56
+ generator,
57
+ variance_noise,
58
+ return_dict,
59
+ )
60
+ # end = timer()
61
+ # print(f"Run Time Inv: {end - start}")
62
+
63
+ res_inf = step_use_latents(
64
+ scheduler,
65
+ model_output[1:, :, :, :],
66
+ timestep,
67
+ sample[1:, :, :, :],
68
+ eta,
69
+ use_clipped_model_output,
70
+ generator,
71
+ variance_noise,
72
+ return_dict,
73
+ )
74
+ # res = res_inv
75
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
76
+ return res
77
+ # return res
78
+
79
+ scheduler.step_function = step_function
80
+ scheduler.is_save = True
81
+ scheduler._timesteps = timesteps
82
+ scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
83
+ scheduler._config = config
84
+ scheduler.latents = latents
85
+ scheduler.x_ts = x_ts
86
+ scheduler.x_ts_c_hat = x_ts_c_hat
87
+ scheduler.step = step
88
+ scheduler.save_intermediate_results = save_intermediate_results
89
+ scheduler.pipe = pipe
90
+ scheduler.v1s_images = v1s_images
91
+ scheduler.v2s_images = v2s_images
92
+ scheduler.deltas_images = deltas_images
93
+ scheduler.v1_x0s = v1_x0s
94
+ scheduler.v2_x0s = v2_x0s
95
+ scheduler.deltas_x0s = deltas_x0s
96
+ scheduler.clean_step_run = False
97
+ scheduler.x_0s = create_xts(
98
+ config.noise_shift_delta,
99
+ config.noise_timesteps,
100
+ config.clean_step_timestep,
101
+ None,
102
+ pipe.scheduler,
103
+ timesteps,
104
+ x_0,
105
+ no_add_noise=True,
106
+ )
107
+ scheduler.folder_name = folder_name
108
+ scheduler.image_name = image_name
109
+ scheduler.p_to_p = False
110
+ scheduler.p_to_p_replace = False
111
+ scheduler.time_measure_n = time_measure_n
112
+ return scheduler
113
+
114
+ def step_save_latents(
115
+ self,
116
+ model_output: torch.FloatTensor,
117
+ timestep: int,
118
+ sample: torch.FloatTensor,
119
+ eta: float = 0.0,
120
+ use_clipped_model_output: bool = False,
121
+ generator=None,
122
+ variance_noise: Optional[torch.FloatTensor] = None,
123
+ return_dict: bool = True,
124
+ ):
125
+ # print(self._save_timesteps)
126
+ # timestep_index = map_timpstep_to_index[timestep]
127
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
128
+ timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
129
+ next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
130
+ u_hat_t = self.step_function(
131
+ model_output=model_output,
132
+ timestep=timestep,
133
+ sample=sample,
134
+ eta=eta,
135
+ use_clipped_model_output=use_clipped_model_output,
136
+ generator=generator,
137
+ variance_noise=variance_noise,
138
+ return_dict=False,
139
+ scheduler=self,
140
+ )
141
+
142
+ x_t_minus_1 = self.x_ts[next_timestep_index]
143
+ self.x_ts_c_hat.append(u_hat_t)
144
+
145
+ z_t = x_t_minus_1 - u_hat_t
146
+ self.latents.append(z_t)
147
+ z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
148
+
149
+ x_t_minus_1_predicted = u_hat_t + z_t
150
+
151
+ if not return_dict:
152
+ return (x_t_minus_1_predicted,)
153
+
154
+ return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
155
+
156
+ def step_use_latents(
157
+ self,
158
+ model_output: torch.FloatTensor,
159
+ timestep: int,
160
+ sample: torch.FloatTensor,
161
+ eta: float = 0.0,
162
+ use_clipped_model_output: bool = False,
163
+ generator=None,
164
+ variance_noise: Optional[torch.FloatTensor] = None,
165
+ return_dict: bool = True,
166
+ ):
167
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
168
+ timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
169
+ next_timestep_index = (
170
+ timestep_index + 1 if not self.clean_step_run else -1
171
+ )
172
+ z_t = self.latents[next_timestep_index] # + 1 because latents[0] is X_T
173
+
174
+ _, normalize_coefficient = normalize(
175
+ z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
176
+ timestep_index,
177
+ self._config.max_norm_zs,
178
+ )
179
+
180
+ if normalize_coefficient == 0:
181
+ eta = 0
182
+
183
+ # eta = normalize_coefficient
184
+
185
+ x_t_hat_c_hat = self.step_function(
186
+ model_output=model_output,
187
+ timestep=timestep,
188
+ sample=sample,
189
+ eta=eta,
190
+ use_clipped_model_output=use_clipped_model_output,
191
+ generator=generator,
192
+ variance_noise=variance_noise,
193
+ return_dict=False,
194
+ scheduler=self,
195
+ )
196
+
197
+ w1 = self._config.ws1[timestep_index]
198
+ w2 = self._config.ws2[timestep_index]
199
+
200
+ x_t_minus_1_exact = self.x_ts[next_timestep_index]
201
+ x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
202
+
203
+ x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
204
+ if self._config.breakdown == "x_t_c_hat":
205
+ raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
206
+
207
+ # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
208
+ x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
209
+
210
+ # if self._config.breakdown == "x_t_c_hat":
211
+ # v1 = x_t_hat_c_hat - x_t_c_hat
212
+ # v2 = x_t_c_hat - x_t_c
213
+ if (
214
+ self._config.breakdown == "x_t_hat_c"
215
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
216
+ ):
217
+ zero_index_reconstruction = 1 if not self.time_measure_n else 0
218
+ edit_prompts_num = (
219
+ (model_output.size(0) - zero_index_reconstruction) // 3
220
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
221
+ else (model_output.size(0) - zero_index_reconstruction) // 2
222
+ )
223
+ x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
224
+ edit_images_indices = (
225
+ edit_prompts_num + zero_index_reconstruction,
226
+ (
227
+ model_output.size(0)
228
+ if self._config.breakdown == "x_t_hat_c"
229
+ else zero_index_reconstruction + 2 * edit_prompts_num
230
+ ),
231
+ )
232
+ x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
233
+ x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
234
+ x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
235
+ ]
236
+ v1 = x_t_hat_c_hat - x_t_hat_c
237
+ v2 = x_t_hat_c - normalize_coefficient * x_t_c
238
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
239
+ path = os.path.join(
240
+ self.folder_name,
241
+ VECTOR_DATA_FOLDER,
242
+ self.image_name,
243
+ )
244
+ if not hasattr(self, VECTOR_DATA_DICT):
245
+ os.makedirs(path, exist_ok=True)
246
+ self.vector_data = dict()
247
+
248
+ x_t_0 = x_t_c_hat[1]
249
+ empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
250
+ x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
251
+
252
+ self.vector_data[timestep.item()] = dict()
253
+ self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
254
+ edit_images_indices[0] : edit_images_indices[1]
255
+ ]
256
+ self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
257
+ self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
258
+ self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
259
+ self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
260
+ edit_images_indices[0] : edit_images_indices[1]
261
+ ]
262
+ self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
263
+ 0
264
+ ].expand_as(x_t_hat_0)
265
+ self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
266
+ next_timestep_index
267
+ ].expand_as(x_t_hat_0)
268
+
269
+ else: # no breakdown
270
+ v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
271
+ v2 = 0
272
+
273
+ if self.save_intermediate_results and not self.p_to_p:
274
+ delta = v1 + v2
275
+ v1_plus_x0 = self.x_0s[next_timestep_index] + v1
276
+ v2_plus_x0 = self.x_0s[next_timestep_index] + v2
277
+ delta_plus_x0 = self.x_0s[next_timestep_index] + delta
278
+
279
+ v1_images = decode_latents(v1, self.pipe)
280
+ self.v1s_images.append(v1_images)
281
+ v2_images = (
282
+ decode_latents(v2, self.pipe)
283
+ if self._config.breakdown != "no_breakdown"
284
+ else [PIL.Image.new("RGB", (1, 1))]
285
+ )
286
+ self.v2s_images.append(v2_images)
287
+ delta_images = decode_latents(delta, self.pipe)
288
+ self.deltas_images.append(delta_images)
289
+ v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
290
+ self.v1_x0s.append(v1_plus_x0_images)
291
+ v2_plus_x0_images = (
292
+ decode_latents(v2_plus_x0, self.pipe)
293
+ if self._config.breakdown != "no_breakdown"
294
+ else [PIL.Image.new("RGB", (1, 1))]
295
+ )
296
+ self.v2_x0s.append(v2_plus_x0_images)
297
+ delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
298
+ self.deltas_x0s.append(delta_plus_x0_images)
299
+
300
+ # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
301
+ # if self._config.breakdown != "no_breakdown":
302
+ # print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
303
+ # print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")
304
+
305
+ x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
306
+
307
+ if (
308
+ self._config.breakdown == "x_t_hat_c"
309
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
310
+ ):
311
+ x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
312
+ edit_images_indices[0] : edit_images_indices[1]
313
+ ] # update x_t_hat_c to be x_t_hat_c_hat
314
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
315
+ x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
316
+ x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
317
+ )
318
+ self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
319
+ edit_images_indices[0] : edit_images_indices[1]
320
+ ]
321
+ if timestep == self._timesteps[-1]:
322
+ torch.save(
323
+ self.vector_data,
324
+ os.path.join(
325
+ path,
326
+ f"{VECTOR_DATA_DICT}.pt",
327
+ ),
328
+ )
329
+ # p_to_p_force_perfect_reconstruction
330
+ if not self.time_measure_n:
331
+ x_t_minus_1[0] = x_t_minus_1_exact[0]
332
+
333
+ if not return_dict:
334
+ return (x_t_minus_1,)
335
+
336
+ return DDIMSchedulerOutput(
337
+ prev_sample=x_t_minus_1,
338
+ pred_original_sample=None,
339
+ )
340
+
341
+ def create_xts(
342
+ noise_shift_delta,
343
+ noise_timesteps,
344
+ clean_step_timestep,
345
+ generator,
346
+ scheduler,
347
+ timesteps,
348
+ x_0,
349
+ no_add_noise=False,
350
+ ):
351
+ if noise_timesteps is None:
352
+ noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
353
+ noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
354
+
355
+ first_x_0_idx = len(noise_timesteps)
356
+ for i in range(len(noise_timesteps)):
357
+ if noise_timesteps[i] <= 0:
358
+ first_x_0_idx = i
359
+ break
360
+
361
+ noise_timesteps = noise_timesteps[:first_x_0_idx]
362
+
363
+ x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
364
+ noise = (
365
+ torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
366
+ x_0.device
367
+ )
368
+ if not no_add_noise
369
+ else torch.zeros_like(x_0_expanded)
370
+ )
371
+ x_ts = scheduler.add_noise(
372
+ x_0_expanded,
373
+ noise,
374
+ torch.IntTensor(noise_timesteps),
375
+ )
376
+ x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
377
+ x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
378
+ x_ts += [x_0]
379
+ if clean_step_timestep > 0:
380
+ x_ts += [x_0]
381
+ return x_ts
382
+
383
+ def normalize(
384
+ z_t,
385
+ i,
386
+ max_norm_zs,
387
+ ):
388
+ max_norm = max_norm_zs[i]
389
+ if max_norm < 0:
390
+ return z_t, 1
391
+
392
+ norm = torch.norm(z_t)
393
+ if norm < max_norm:
394
+ return z_t, 1
395
+
396
+ coeff = max_norm / norm
397
+ z_t = z_t * coeff
398
+ return z_t, coeff
399
+
400
+ def decode_latents(latent, pipe):
401
+ latent_img = pipe.vae.decode(
402
+ latent / pipe.vae.config.scaling_factor, return_dict=False
403
+ )[0]
404
+ return pipe.image_processor.postprocess(latent_img, output_type="pil")
405
+
406
+ def deterministic_ddim_step(
407
+ model_output: torch.FloatTensor,
408
+ timestep: int,
409
+ sample: torch.FloatTensor,
410
+ eta: float = 0.0,
411
+ use_clipped_model_output: bool = False,
412
+ generator=None,
413
+ variance_noise: Optional[torch.FloatTensor] = None,
414
+ return_dict: bool = True,
415
+ scheduler=None,
416
+ ):
417
+
418
+ if scheduler.num_inference_steps is None:
419
+ raise ValueError(
420
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
421
+ )
422
+
423
+ prev_timestep = (
424
+ timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
425
+ )
426
+
427
+ # 2. compute alphas, betas
428
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
429
+ alpha_prod_t_prev = (
430
+ scheduler.alphas_cumprod[prev_timestep]
431
+ if prev_timestep >= 0
432
+ else scheduler.final_alpha_cumprod
433
+ )
434
+
435
+ beta_prod_t = 1 - alpha_prod_t
436
+
437
+ if scheduler.config.prediction_type == "epsilon":
438
+ pred_original_sample = (
439
+ sample - beta_prod_t ** (0.5) * model_output
440
+ ) / alpha_prod_t ** (0.5)
441
+ pred_epsilon = model_output
442
+ elif scheduler.config.prediction_type == "sample":
443
+ pred_original_sample = model_output
444
+ pred_epsilon = (
445
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
446
+ ) / beta_prod_t ** (0.5)
447
+ elif scheduler.config.prediction_type == "v_prediction":
448
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
449
+ beta_prod_t**0.5
450
+ ) * model_output
451
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
452
+ else:
453
+ raise ValueError(
454
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
455
+ " `v_prediction`"
456
+ )
457
+
458
+ # 4. Clip or threshold "predicted x_0"
459
+ if scheduler.config.thresholding:
460
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
461
+ elif scheduler.config.clip_sample:
462
+ pred_original_sample = pred_original_sample.clamp(
463
+ -scheduler.config.clip_sample_range,
464
+ scheduler.config.clip_sample_range,
465
+ )
466
+
467
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
468
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
469
+ variance = scheduler._get_variance(timestep, prev_timestep)
470
+ std_dev_t = eta * variance ** (0.5)
471
+
472
+ if use_clipped_model_output:
473
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
474
+ pred_epsilon = (
475
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
476
+ ) / beta_prod_t ** (0.5)
477
+
478
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
479
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
480
+ 0.5
481
+ ) * pred_epsilon
482
+
483
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
484
+ prev_sample = (
485
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
486
+ )
487
+ return prev_sample
488
+
489
+
490
+ def deterministic_euler_step(
491
+ model_output: torch.FloatTensor,
492
+ timestep: Union[float, torch.FloatTensor],
493
+ sample: torch.FloatTensor,
494
+ eta,
495
+ use_clipped_model_output,
496
+ generator,
497
+ variance_noise,
498
+ return_dict,
499
+ scheduler,
500
+ ):
501
+ """
502
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
503
+ process from the learned model outputs (most often the predicted noise).
504
+
505
+ Args:
506
+ model_output (`torch.FloatTensor`):
507
+ The direct output from learned diffusion model.
508
+ timestep (`float`):
509
+ The current discrete timestep in the diffusion chain.
510
+ sample (`torch.FloatTensor`):
511
+ A current instance of a sample created by the diffusion process.
512
+ generator (`torch.Generator`, *optional*):
513
+ A random number generator.
514
+ return_dict (`bool`):
515
+ Whether or not to return a
516
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
517
+
518
+ Returns:
519
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
520
+ If return_dict is `True`,
521
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
522
+ otherwise a tuple is returned where the first element is the sample tensor.
523
+
524
+ """
525
+
526
+ if (
527
+ isinstance(timestep, int)
528
+ or isinstance(timestep, torch.IntTensor)
529
+ or isinstance(timestep, torch.LongTensor)
530
+ ):
531
+ raise ValueError(
532
+ (
533
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
534
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
535
+ " one of the `scheduler.timesteps` as a timestep."
536
+ ),
537
+ )
538
+
539
+ if scheduler.step_index is None:
540
+ scheduler._init_step_index(timestep)
541
+
542
+ sigma = scheduler.sigmas[scheduler.step_index]
543
+
544
+ # Upcast to avoid precision issues when computing prev_sample
545
+ sample = sample.to(torch.float32)
546
+
547
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
548
+ if scheduler.config.prediction_type == "epsilon":
549
+ pred_original_sample = sample - sigma * model_output
550
+ elif scheduler.config.prediction_type == "v_prediction":
551
+ # * c_out + input * c_skip
552
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
553
+ sample / (sigma**2 + 1)
554
+ )
555
+ elif scheduler.config.prediction_type == "sample":
556
+ raise NotImplementedError("prediction_type not implemented yet: sample")
557
+ else:
558
+ raise ValueError(
559
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
560
+ )
561
+
562
+ sigma_from = scheduler.sigmas[scheduler.step_index]
563
+ sigma_to = scheduler.sigmas[scheduler.step_index + 1]
564
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
565
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
566
+
567
+ # 2. Convert to an ODE derivative
568
+ derivative = (sample - pred_original_sample) / sigma
569
+
570
+ dt = sigma_down - sigma
571
+
572
+ prev_sample = sample + derivative * dt
573
+
574
+ # Cast sample back to model compatible dtype
575
+ prev_sample = prev_sample.to(model_output.dtype)
576
+
577
+ # upon completion increase step index by one
578
+ scheduler._step_index += 1
579
+
580
+ return prev_sample
581
+
582
+
583
+ def deterministic_non_ancestral_euler_step(
584
+ model_output: torch.FloatTensor,
585
+ timestep: Union[float, torch.FloatTensor],
586
+ sample: torch.FloatTensor,
587
+ eta: float = 0.0,
588
+ use_clipped_model_output: bool = False,
589
+ s_churn: float = 0.0,
590
+ s_tmin: float = 0.0,
591
+ s_tmax: float = float("inf"),
592
+ s_noise: float = 1.0,
593
+ generator: Optional[torch.Generator] = None,
594
+ variance_noise: Optional[torch.FloatTensor] = None,
595
+ return_dict: bool = True,
596
+ scheduler=None,
597
+ ):
598
+ """
599
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
600
+ process from the learned model outputs (most often the predicted noise).
601
+
602
+ Args:
603
+ model_output (`torch.FloatTensor`):
604
+ The direct output from learned diffusion model.
605
+ timestep (`float`):
606
+ The current discrete timestep in the diffusion chain.
607
+ sample (`torch.FloatTensor`):
608
+ A current instance of a sample created by the diffusion process.
609
+ s_churn (`float`):
610
+ s_tmin (`float`):
611
+ s_tmax (`float`):
612
+ s_noise (`float`, defaults to 1.0):
613
+ Scaling factor for noise added to the sample.
614
+ generator (`torch.Generator`, *optional*):
615
+ A random number generator.
616
+ return_dict (`bool`):
617
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
618
+ tuple.
619
+
620
+ Returns:
621
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
622
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
623
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
624
+ """
625
+
626
+ if (
627
+ isinstance(timestep, int)
628
+ or isinstance(timestep, torch.IntTensor)
629
+ or isinstance(timestep, torch.LongTensor)
630
+ ):
631
+ raise ValueError(
632
+ (
633
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
634
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
635
+ " one of the `scheduler.timesteps` as a timestep."
636
+ ),
637
+ )
638
+
639
+ if not scheduler.is_scale_input_called:
640
+ logger.warning(
641
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
642
+ "See `StableDiffusionPipeline` for a usage example."
643
+ )
644
+
645
+ if scheduler.step_index is None:
646
+ scheduler._init_step_index(timestep)
647
+
648
+ # Upcast to avoid precision issues when computing prev_sample
649
+ sample = sample.to(torch.float32)
650
+
651
+ sigma = scheduler.sigmas[scheduler.step_index]
652
+
653
+ gamma = (
654
+ min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
655
+ if s_tmin <= sigma <= s_tmax
656
+ else 0.0
657
+ )
658
+
659
+ sigma_hat = sigma * (gamma + 1)
660
+
661
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
662
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
663
+ # backwards compatibility
664
+ if (
665
+ scheduler.config.prediction_type == "original_sample"
666
+ or scheduler.config.prediction_type == "sample"
667
+ ):
668
+ pred_original_sample = model_output
669
+ elif scheduler.config.prediction_type == "epsilon":
670
+ pred_original_sample = sample - sigma_hat * model_output
671
+ elif scheduler.config.prediction_type == "v_prediction":
672
+ # denoised = model_output * c_out + input * c_skip
673
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
674
+ sample / (sigma**2 + 1)
675
+ )
676
+ else:
677
+ raise ValueError(
678
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
679
+ )
680
+
681
+ # 2. Convert to an ODE derivative
682
+ derivative = (sample - pred_original_sample) / sigma_hat
683
+
684
+ dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
685
+
686
+ prev_sample = sample + derivative * dt
687
+
688
+ # Cast sample back to model compatible dtype
689
+ prev_sample = prev_sample.to(model_output.dtype)
690
+
691
+ # upon completion increase step index by one
692
+ scheduler._step_index += 1
693
+
694
+ return prev_sample
695
+
696
+
697
+ def deterministic_ddpm_step(
698
+ model_output: torch.FloatTensor,
699
+ timestep: Union[float, torch.FloatTensor],
700
+ sample: torch.FloatTensor,
701
+ eta,
702
+ use_clipped_model_output,
703
+ generator,
704
+ variance_noise,
705
+ return_dict,
706
+ scheduler,
707
+ ):
708
+ """
709
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
710
+ process from the learned model outputs (most often the predicted noise).
711
+
712
+ Args:
713
+ model_output (`torch.FloatTensor`):
714
+ The direct output from learned diffusion model.
715
+ timestep (`float`):
716
+ The current discrete timestep in the diffusion chain.
717
+ sample (`torch.FloatTensor`):
718
+ A current instance of a sample created by the diffusion process.
719
+ generator (`torch.Generator`, *optional*):
720
+ A random number generator.
721
+ return_dict (`bool`, *optional*, defaults to `True`):
722
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
723
+
724
+ Returns:
725
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
726
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
727
+ tuple is returned where the first element is the sample tensor.
728
+
729
+ """
730
+ t = timestep
731
+
732
+ prev_t = scheduler.previous_timestep(t)
733
+
734
+ if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
735
+ "learned",
736
+ "learned_range",
737
+ ]:
738
+ model_output, predicted_variance = torch.split(
739
+ model_output, sample.shape[1], dim=1
740
+ )
741
+ else:
742
+ predicted_variance = None
743
+
744
+ # 1. compute alphas, betas
745
+ alpha_prod_t = scheduler.alphas_cumprod[t]
746
+ alpha_prod_t_prev = (
747
+ scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
748
+ )
749
+ beta_prod_t = 1 - alpha_prod_t
750
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
751
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
752
+ current_beta_t = 1 - current_alpha_t
753
+
754
+ # 2. compute predicted original sample from predicted noise also called
755
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
756
+ if scheduler.config.prediction_type == "epsilon":
757
+ pred_original_sample = (
758
+ sample - beta_prod_t ** (0.5) * model_output
759
+ ) / alpha_prod_t ** (0.5)
760
+ elif scheduler.config.prediction_type == "sample":
761
+ pred_original_sample = model_output
762
+ elif scheduler.config.prediction_type == "v_prediction":
763
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
764
+ beta_prod_t**0.5
765
+ ) * model_output
766
+ else:
767
+ raise ValueError(
768
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
769
+ " `v_prediction` for the DDPMScheduler."
770
+ )
771
+
772
+ # 3. Clip or threshold "predicted x_0"
773
+ if scheduler.config.thresholding:
774
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
775
+ elif scheduler.config.clip_sample:
776
+ pred_original_sample = pred_original_sample.clamp(
777
+ -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
778
+ )
779
+
780
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
781
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
782
+ pred_original_sample_coeff = (
783
+ alpha_prod_t_prev ** (0.5) * current_beta_t
784
+ ) / beta_prod_t
785
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
786
+
787
+ # 5. Compute predicted previous sample µ_t
788
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
789
+ pred_prev_sample = (
790
+ pred_original_sample_coeff * pred_original_sample
791
+ + current_sample_coeff * sample
792
+ )
793
+
794
+ return pred_prev_sample
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ml-collections
2
+ gradio
3
+ torch
4
+ torchvision
5
+ diffusers
6
+ transformers
7
+ accelerate
8
+ mediapipe
9
+ spaces
10
+ sentencepiece
11
+ compel
12
+ gfpgan
13
+ git+https://github.com/XPixelGroup/BasicSR@master
14
+ facexlib
15
+ realesrgan
16
+ controlnet_aux
17
+ peft
run_configs/noise_shift_3_steps.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ breakdown: "x_t_hat_c"
2
+ cross_r: 0.9
3
+ eta_reconstruct: 1
4
+ eta_retrieve: 1
5
+ max_norm_zs: [-1, -1, 15.5]
6
+ model: "stabilityai/sdxl-turbo"
7
+ noise_shift_delta: 1
8
+ noise_timesteps: [599, 299, 0]
9
+ timesteps: [799, 499, 199]
10
+ num_steps_inversion: 5
11
+ step_start: 1
12
+ real_cfg_scale: 0
13
+ real_cfg_scale_save: 0
14
+ scheduler_type: "ddpm"
15
+ seed: 2
16
+ self_r: 0.5
17
+ ws1: 1.5
18
+ ws2: 1
19
+ clean_step_timestep: 0
run_configs/noise_shift_guidance_1_5.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ breakdown: "x_t_hat_c"
2
+ cross_r: 0.9
3
+ eta: 1
4
+ max_norm_zs: [-1, -1, -1, 15.5]
5
+ model: ""
6
+ noise_shift_delta: 1
7
+ noise_timesteps: null
8
+ num_steps_inversion: 20
9
+ step_start: 5
10
+ real_cfg_scale: 0
11
+ real_cfg_scale_save: 0
12
+ scheduler_type: "ddpm"
13
+ seed: 2
14
+ self_r: 0.5
15
+ timesteps: null
16
+ ws1: 1.5
17
+ ws2: 1
18
+ clean_step_timestep: 0
segment_utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import mediapipe as mp
3
+ import uuid
4
+
5
+ from PIL import Image
6
+ from mediapipe.tasks import python
7
+ from mediapipe.tasks.python import vision
8
+ from scipy.ndimage import binary_dilation
9
+ from croper import Croper
10
+
11
+ segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
12
+ base_options = python.BaseOptions(model_asset_path=segment_model)
13
+ options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
14
+ segmenter = vision.ImageSegmenter.create_from_options(options)
15
+
16
+ def restore_result(croper, category, generated_image):
17
+ square_length = croper.square_length
18
+ generated_image = generated_image.resize((square_length, square_length))
19
+
20
+ cropped_generated_image = generated_image.crop((croper.square_start_x, croper.square_start_y, croper.square_end_x, croper.square_end_y))
21
+ cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)
22
+
23
+ restored_image = croper.input_image.copy()
24
+ restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)
25
+
26
+ extension = 'png'
27
+ if restored_image.mode == 'RGBA':
28
+ extension = 'png'
29
+ else:
30
+ extension = 'jpg'
31
+
32
+ path = f"output/{uuid.uuid4()}.{extension}"
33
+ restored_image.save(path)
34
+
35
+ return restored_image, path
36
+
37
+ def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
38
+ mask_size = int(input_size)
39
+ mask_expansion = int(mask_expansion)
40
+
41
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
42
+ segmentation_result = segmenter.segment(image)
43
+ category_mask = segmentation_result.category_mask
44
+ category_mask_np = category_mask.numpy_view()
45
+
46
+ if category == "hair":
47
+ target_mask = get_hair_mask(category_mask_np, mask_dilation)
48
+ elif category == "clothes":
49
+ target_mask = get_clothes_mask(category_mask_np, mask_dilation)
50
+ elif category == "face":
51
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
52
+ else:
53
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
54
+
55
+ croper = Croper(input_image, target_mask, mask_size, mask_expansion)
56
+ croper.corp_mask_image()
57
+ origin_area_image = croper.resized_square_image
58
+
59
+ return origin_area_image, croper
60
+
61
+ def get_face_mask(category_mask_np, dilation=1):
62
+ face_skin_mask = category_mask_np == 3
63
+ if dilation > 0:
64
+ face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)
65
+
66
+ return face_skin_mask
67
+
68
+ def get_clothes_mask(category_mask_np, dilation=1):
69
+ body_skin_mask = category_mask_np == 2
70
+ clothes_mask = category_mask_np == 4
71
+ combined_mask = np.logical_or(body_skin_mask, clothes_mask)
72
+ combined_mask = binary_dilation(combined_mask, iterations=4)
73
+ if dilation > 0:
74
+ combined_mask = binary_dilation(combined_mask, iterations=dilation)
75
+ return combined_mask
76
+
77
+ def get_hair_mask(category_mask_np, dilation=1):
78
+ hair_mask = category_mask_np == 1
79
+ if dilation > 0:
80
+ hair_mask = binary_dilation(hair_mask, iterations=dilation)
81
+ return hair_mask
82
+
83
+ def get_restore_mask_image(croper, category, generated_image):
84
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
85
+ segmentation_result = segmenter.segment(image)
86
+ category_mask = segmentation_result.category_mask
87
+ category_mask_np = category_mask.numpy_view()
88
+
89
+ if category == "hair":
90
+ target_mask = get_hair_mask(category_mask_np, 0)
91
+ elif category == "clothes":
92
+ target_mask = get_clothes_mask(category_mask_np, 0)
93
+ elif category == "face":
94
+ target_mask = get_face_mask(category_mask_np, 0)
95
+
96
+ combined_mask = np.logical_or(target_mask, croper.corp_mask)
97
+ mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
98
+ return mask_image