import random
import numpy as np
import torch
from videosys.utils.logging import logger
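# Global handle to the active PAB (Pyramid Attention Broadcast) manager.
# The manager decides, per diffusion timestep and transformer block, whether a
# cached attention / MLP output from an earlier step can be broadcast (reused)
# instead of recomputed, and whether diffusion timesteps can be skipped.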
PAB_MANAGER = None
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool,
cross_threshold: list,
cross_gap: int,
spatial_broadcast: bool,
spatial_threshold: list,
spatial_gap: int,
temporal_broadcast: bool,
temporal_threshold: list,
temporal_gap: int,
diffusion_skip: bool,
diffusion_timestep_respacing: list,
diffusion_skip_timestep: list,
mlp_skip: bool,
mlp_spatial_skip_config: dict,
mlp_temporal_skip_config: dict,
full_broadcast: bool = False,
full_threshold: list = None,
full_gap: int = 1,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
self.cross_gap = cross_gap
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
self.spatial_gap = spatial_gap
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
self.temporal_gap = temporal_gap
self.diffusion_skip = diffusion_skip
self.diffusion_timestep_respacing = diffusion_timestep_respacing
self.diffusion_skip_timestep = diffusion_skip_timestep
self.mlp_skip = mlp_skip
self.mlp_spatial_skip_config = mlp_spatial_skip_config
self.mlp_temporal_skip_config = mlp_temporal_skip_config
self.temporal_mlp_outputs = {}
self.spatial_mlp_outputs = {}
self.full_broadcast = full_broadcast
self.full_threshold = full_threshold
self.full_gap = full_gap
class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
init_prompt = f"Init PABManager. steps: {config.steps}."
init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
logger.info(init_prompt)
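    # Each if_broadcast_* helper returns (flag, count). flag is True when the
    # output cached at an earlier step should be reused for this timestep, i.e.
    # broadcasting is enabled, the timestep lies inside the configured threshold
    # window, and the running counter is not a multiple of the gap. count is the
    # counter advanced modulo the total number of steps.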
def if_broadcast_cross(self, timestep: int, count: int):
if (
self.config.cross_broadcast
and (timestep is not None)
and (count % self.config.cross_gap != 0)
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_temporal(self, timestep: int, count: int):
if (
self.config.temporal_broadcast
and (timestep is not None)
and (count % self.config.temporal_gap != 0)
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
if (
self.config.spatial_broadcast
and (timestep is not None)
and (count % self.config.spatial_gap != 0)
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
if (
self.config.full_broadcast
and (timestep is not None)
and (count % self.config.full_gap != 0)
and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
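    # mlp_spatial_skip_config / mlp_temporal_skip_config map a starting timestep
    # to {"block": [...], "skip_count": n}. This helper checks whether `timestep`
    # falls inside such a window of skip_count consecutive timesteps and, if so,
    # returns the window as [start_timestep, end_timestep].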
@staticmethod
    def _is_t_in_skip_config(all_timesteps, timestep, config):
        is_t_in_skip_config = False
        skip_range = None
for key in config:
if key not in all_timesteps:
continue
index = all_timesteps.index(key)
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
if timestep in skip_range:
is_t_in_skip_config = True
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
break
return is_t_in_skip_config, skip_range
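    # Returns (flag, count, next_flag, skip_range):
    #   flag       - True when this block's MLP output should be skipped and
    #                replaced by the output cached at the start of the window.
    #   next_flag  - True when this timestep opens a skip window, so the freshly
    #                computed MLP output should be cached for later reuse.
    #   skip_range - the [start, end] timestep window used as the cache key.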
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not self.config.mlp_skip:
            return False, count, False, None
if is_temporal:
cur_config = self.config.mlp_temporal_skip_config
else:
cur_config = self.config.mlp_spatial_skip_config
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
next_flag = False
if (
self.config.mlp_skip
and (timestep is not None)
and (timestep in cur_config)
and (block_idx in cur_config[timestep]["block"])
):
flag = False
next_flag = True
count = count + 1
elif (
self.config.mlp_skip
and (timestep is not None)
and (is_t_in_skip_config)
and (block_idx in cur_config[skip_range[0]]["block"])
):
flag = True
count = 0
else:
flag = False
return flag, count, next_flag, skip_range
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
if is_temporal:
self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
else:
self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
skip_start_t = skip_range[0]
if is_temporal:
skip_output = (
self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
if self.config.temporal_mlp_outputs is not None
else None
)
else:
skip_output = (
self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
if self.config.spatial_mlp_outputs is not None
else None
)
if skip_output is not None:
if timestep == skip_range[-1]:
# TODO: save memory
if is_temporal:
del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
else:
del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
else:
raise ValueError(
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
)
return skip_output
def get_spatial_mlp_outputs(self):
return self.config.spatial_mlp_outputs
def get_temporal_mlp_outputs(self):
return self.config.temporal_mlp_outputs
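# Module-level wrappers around the global PAB_MANAGER so model code can query
# broadcast / skip decisions without holding a reference to the manager.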
def set_pab_manager(config: PABConfig):
global PAB_MANAGER
PAB_MANAGER = PABManager(config)
def enable_pab():
if PAB_MANAGER is None:
return False
return (
PAB_MANAGER.config.cross_broadcast
or PAB_MANAGER.config.spatial_broadcast
or PAB_MANAGER.config.temporal_broadcast
)
def update_steps(steps: int):
if PAB_MANAGER is not None:
PAB_MANAGER.config.steps = steps
def if_broadcast_cross(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_cross(timestep, count)
def if_broadcast_temporal(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
def if_broadcast_full(timestep: int, count: int, block_idx: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not enable_pab():
return False, count
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
def get_diffusion_skip():
return enable_pab() and PAB_MANAGER.config.diffusion_skip
def get_diffusion_timestep_respacing():
return PAB_MANAGER.config.diffusion_timestep_respacing
def get_diffusion_skip_timestep():
return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
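# Evenly spaces `bin_count` timesteps within each of len(time_bins) equal bins of
# [0, time_steps) and returns them as one descending int32 tensor, e.g.
# space_timesteps(100, [10, 5]) keeps 10 steps in [0, 50) and 5 steps in [50, 100).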
def space_timesteps(time_steps, time_bins):
num_bins = len(time_bins)
bin_size = time_steps // num_bins
result = []
for i, bin_count in enumerate(time_bins):
start = i * bin_size
end = start + bin_size
bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
result.extend(bin_steps)
result_tensor = torch.tensor(result, dtype=torch.int32)
sorted_tensor = torch.sort(result_tensor, descending=True).values
return sorted_tensor
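# Thins out a timestep schedule bin by bin: a 0 in diffusion_skip_timestep keeps
# the whole bin, a 1 drops the first timestep of the bin, and any larger value n
# randomly drops n timesteps from the bin (the list-input branch only handles 0 and 1).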
def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
if isinstance(timesteps, list):
# If timesteps is a list, we assume each element is a tensor
timesteps_np = [t.cpu().numpy() for t in timesteps]
device = timesteps[0].device
else:
# If timesteps is a tensor
timesteps_np = timesteps.cpu().numpy()
device = timesteps.device
num_bins = len(diffusion_skip_timestep)
if isinstance(timesteps_np, list):
bin_size = len(timesteps_np) // num_bins
new_timesteps = []
for i in range(num_bins):
bin_start = i * bin_size
bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
bin_timesteps = timesteps_np[bin_start:bin_end]
if diffusion_skip_timestep[i] == 0:
# If the bin is marked with 0, keep all timesteps
new_timesteps.extend(bin_timesteps)
elif diffusion_skip_timestep[i] == 1:
                # If the bin is marked with 1, omit the first timestep in the bin
                new_timesteps.extend(bin_timesteps[1:])
new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
else:
bin_size = len(timesteps_np) // num_bins
new_timesteps = []
for i in range(num_bins):
bin_start = i * bin_size
bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
bin_timesteps = timesteps_np[bin_start:bin_end]
if diffusion_skip_timestep[i] == 0:
# If the bin is marked with 0, keep all timesteps
new_timesteps.extend(bin_timesteps)
elif diffusion_skip_timestep[i] == 1:
                # If the bin is marked with 1, omit the first timestep in the bin
                new_timesteps.extend(bin_timesteps[1:])
            elif diffusion_skip_timestep[i] != 0:
                # Any other non-zero value n: randomly omit n timesteps from the bin
if len(bin_timesteps) > diffusion_skip_timestep[i]:
indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
timesteps_to_keep = [
timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
]
else:
                    timesteps_to_keep = bin_timesteps  # if bin_timesteps has no more than n elements, keep them all
new_timesteps.extend(timesteps_to_keep)
new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
    return new_timesteps_tensor
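# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumption: not part of the original module).
# The step count, thresholds and gaps below are placeholder values chosen only
# to demonstrate the API, not recommended settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_config = PABConfig(
        steps=50,
        cross_broadcast=True,
        cross_threshold=[100, 900],
        cross_gap=5,
        spatial_broadcast=True,
        spatial_threshold=[100, 900],
        spatial_gap=2,
        temporal_broadcast=True,
        temporal_threshold=[100, 900],
        temporal_gap=3,
        diffusion_skip=False,
        diffusion_timestep_respacing=None,
        diffusion_skip_timestep=None,
        mlp_skip=False,
        mlp_spatial_skip_config={},
        mlp_temporal_skip_config={},
    )
    set_pab_manager(demo_config)

    # A transformer block would query the manager like this at each denoising step:
    # reuse its cached spatial attention output when flag is True, otherwise recompute.
    count = 0
    for t in [950, 500, 499, 498, 50]:
        flag, count = if_broadcast_spatial(t, count, block_idx=0)
        print(f"t={t}: reuse cached spatial attention -> {flag}")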