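"""Pyramid Attention Broadcast (PAB) runtime utilities for VideoSys.

A global PABManager decides, per diffusion timestep and per transformer block,
whether cached cross/spatial/temporal (or full) attention outputs and MLP
outputs can be broadcast (reused) instead of recomputed, and whether whole
diffusion timesteps can be skipped.
"""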

import random

import numpy as np
import torch

from videosys.utils.logging import logger

PAB_MANAGER = None

class PABConfig:
    """Configuration describing which attention/MLP outputs may be broadcast
    (reused) across neighbouring diffusion timesteps, and within which timestep
    windows and step gaps."""

    def __init__(
        self,
        steps: int,
        cross_broadcast: bool,
        cross_threshold: list,
        cross_gap: int,
        spatial_broadcast: bool,
        spatial_threshold: list,
        spatial_gap: int,
        temporal_broadcast: bool,
        temporal_threshold: list,
        temporal_gap: int,
        diffusion_skip: bool,
        diffusion_timestep_respacing: list,
        diffusion_skip_timestep: list,
        mlp_skip: bool,
        mlp_spatial_skip_config: dict,
        mlp_temporal_skip_config: dict,
        full_broadcast: bool = False,
        full_threshold: list = None,
        full_gap: int = 1,
    ):
        self.steps = steps

        self.cross_broadcast = cross_broadcast
        self.cross_threshold = cross_threshold
        self.cross_gap = cross_gap

        self.spatial_broadcast = spatial_broadcast
        self.spatial_threshold = spatial_threshold
        self.spatial_gap = spatial_gap

        self.temporal_broadcast = temporal_broadcast
        self.temporal_threshold = temporal_threshold
        self.temporal_gap = temporal_gap

        self.diffusion_skip = diffusion_skip
        self.diffusion_timestep_respacing = diffusion_timestep_respacing
        self.diffusion_skip_timestep = diffusion_skip_timestep

        self.mlp_skip = mlp_skip
        self.mlp_spatial_skip_config = mlp_spatial_skip_config
        self.mlp_temporal_skip_config = mlp_temporal_skip_config

        # Caches for MLP (feed-forward) outputs reused across skipped timesteps.
        self.temporal_mlp_outputs = {}
        self.spatial_mlp_outputs = {}

        self.full_broadcast = full_broadcast
        self.full_threshold = full_threshold
        self.full_gap = full_gap
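
# Example construction (illustrative values only; actual model integrations
# ship their own PABConfig subclasses with tuned thresholds and gaps):
#
#   config = PABConfig(
#       steps=100,
#       cross_broadcast=True, cross_threshold=[100, 850], cross_gap=6,
#       spatial_broadcast=True, spatial_threshold=[100, 850], spatial_gap=2,
#       temporal_broadcast=True, temporal_threshold=[100, 850], temporal_gap=4,
#       diffusion_skip=False, diffusion_timestep_respacing=None, diffusion_skip_timestep=None,
#       mlp_skip=False, mlp_spatial_skip_config={}, mlp_temporal_skip_config={},
#   )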

class PABManager:
    """Runtime helper that answers, per timestep and block, whether a cached
    output can be broadcast (reused) instead of recomputed."""

    def __init__(self, config: PABConfig):
        self.config: PABConfig = config

        init_prompt = f"Init PABManager. steps: {config.steps}."
        init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
        init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
        init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
        init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
        logger.info(init_prompt)

    def if_broadcast_cross(self, timestep: int, count: int):
        # Reuse the cached cross-attention output when the timestep lies inside
        # the configured threshold window and this step is not a "refresh" step
        # (i.e. count is not a multiple of cross_gap).
        if (
            self.config.cross_broadcast
            and (timestep is not None)
            and (count % self.config.cross_gap != 0)
            and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
        ):
            flag = True
        else:
            flag = False
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_temporal(self, timestep: int, count: int):
        if (
            self.config.temporal_broadcast
            and (timestep is not None)
            and (count % self.config.temporal_gap != 0)
            and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
        ):
            flag = True
        else:
            flag = False
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
        if (
            self.config.spatial_broadcast
            and (timestep is not None)
            and (count % self.config.spatial_gap != 0)
            and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
        ):
            flag = True
        else:
            flag = False
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
        if (
            self.config.full_broadcast
            and (timestep is not None)
            and (count % self.config.full_gap != 0)
            and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
        ):
            flag = True
        else:
            flag = False
        count = (count + 1) % self.config.steps
        return flag, count

    def _is_t_in_skip_config(self, all_timesteps, timestep, config):
        is_t_in_skip_config = False
        skip_range = None
        for key in config:
            if key not in all_timesteps:
                continue
            index = all_timesteps.index(key)
            skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
            if timestep in skip_range:
                is_t_in_skip_config = True
                skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
                break
        return is_t_in_skip_config, skip_range
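
    # The MLP skip-config dicts are keyed by an anchor timestep and map to the
    # blocks that may skip and how many subsequent timesteps the cached output
    # covers, e.g. (illustrative values):
    #   mlp_spatial_skip_config = {
    #       738: {"block": [0, 1, 2], "skip_count": 2},
    #   }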

    def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
        if not self.config.mlp_skip:
            return False, None, False, None

        if is_temporal:
            cur_config = self.config.mlp_temporal_skip_config
        else:
            cur_config = self.config.mlp_spatial_skip_config

        is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
        next_flag = False
        if (
            self.config.mlp_skip
            and (timestep is not None)
            and (timestep in cur_config)
            and (block_idx in cur_config[timestep]["block"])
        ):
            # Anchor timestep of a skip range: compute the MLP normally and cache
            # its output for the following skipped timesteps.
            flag = False
            next_flag = True
            count = count + 1
        elif (
            self.config.mlp_skip
            and (timestep is not None)
            and (is_t_in_skip_config)
            and (block_idx in cur_config[skip_range[0]]["block"])
        ):
            # Timestep inside a skip range: skip the MLP and reuse the cached output.
            flag = True
            count = 0
        else:
            flag = False

        return flag, count, next_flag, skip_range

    def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
        if is_temporal:
            self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
        else:
            self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output

    def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
        skip_start_t = skip_range[0]
        if is_temporal:
            skip_output = (
                self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
                if self.config.temporal_mlp_outputs is not None
                else None
            )
        else:
            skip_output = (
                self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
                if self.config.spatial_mlp_outputs is not None
                else None
            )

        if skip_output is not None:
            if timestep == skip_range[-1]:
                # TODO: save memory
                if is_temporal:
                    del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
                else:
                    del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
        else:
            raise ValueError(
                f"No stored MLP output found | t {timestep} | [{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
            )

        return skip_output

    def get_spatial_mlp_outputs(self):
        return self.config.spatial_mlp_outputs

    def get_temporal_mlp_outputs(self):
        return self.config.temporal_mlp_outputs

def set_pab_manager(config: PABConfig):
    global PAB_MANAGER
    PAB_MANAGER = PABManager(config)

def enable_pab():
    if PAB_MANAGER is None:
        return False
    return (
        PAB_MANAGER.config.cross_broadcast
        or PAB_MANAGER.config.spatial_broadcast
        or PAB_MANAGER.config.temporal_broadcast
    )

def update_steps(steps: int):
    if PAB_MANAGER is not None:
        PAB_MANAGER.config.steps = steps

def if_broadcast_cross(timestep: int, count: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_cross(timestep, count)

def if_broadcast_temporal(timestep: int, count: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_temporal(timestep, count)

def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)

def if_broadcast_full(timestep: int, count: int, block_idx: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)

def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)

def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
    return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)

def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
    return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)

def get_diffusion_skip():
    return enable_pab() and PAB_MANAGER.config.diffusion_skip

def get_diffusion_timestep_respacing():
    return PAB_MANAGER.config.diffusion_timestep_respacing

def get_diffusion_skip_timestep():
    return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
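
# Minimal usage sketch (hypothetical wiring; in practice the model's attention
# blocks call these helpers themselves). `config` is assumed to be a fully
# populated PABConfig and `timesteps` the scheduler's timestep sequence:
#
#   set_pab_manager(config)
#   update_steps(len(timesteps))
#   count = 0
#   for t in timesteps:
#       broadcast, count = if_broadcast_spatial(int(t), count, block_idx=0)
#       # `broadcast` == True means the block should reuse its cached
#       # spatial-attention output for this timestep instead of recomputing it.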

def space_timesteps(time_steps, time_bins):
    # Split the [0, time_steps) range into len(time_bins) equal bins and sample
    # time_bins[i] evenly spaced integer timesteps from bin i, returned as a
    # descending int32 tensor.
    num_bins = len(time_bins)
    bin_size = time_steps // num_bins

    result = []
    for i, bin_count in enumerate(time_bins):
        start = i * bin_size
        end = start + bin_size
        bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
        result.extend(bin_steps)

    result_tensor = torch.tensor(result, dtype=torch.int32)
    sorted_tensor = torch.sort(result_tensor, descending=True).values
    return sorted_tensor
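
# Worked example (illustrative): space_timesteps(100, [5, 5]) splits [0, 100)
# into two bins of 50 steps and samples 5 evenly spaced integers from each,
# returning them in descending order:
#   tensor([90, 80, 70, 60, 50, 40, 30, 20, 10,  0], dtype=torch.int32)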

def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
    if isinstance(timesteps, list):
        # If timesteps is a list, assume each element is a tensor.
        timesteps_np = [t.cpu().numpy() for t in timesteps]
        device = timesteps[0].device
    else:
        # If timesteps is a tensor.
        timesteps_np = timesteps.cpu().numpy()
        device = timesteps.device

    num_bins = len(diffusion_skip_timestep)

    if isinstance(timesteps_np, list):
        # List input: only markers 0 and 1 are handled per bin.
        bin_size = len(timesteps_np) // num_bins
        new_timesteps = []

        for i in range(num_bins):
            bin_start = i * bin_size
            bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
            bin_timesteps = timesteps_np[bin_start:bin_end]

            if diffusion_skip_timestep[i] == 0:
                # Bin marked with 0: keep all timesteps.
                new_timesteps.extend(bin_timesteps)
            elif diffusion_skip_timestep[i] == 1:
                # Bin marked with 1: drop the first timestep of the bin.
                new_timesteps.extend(bin_timesteps[1:])

        new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
    else:
        bin_size = len(timesteps_np) // num_bins
        new_timesteps = []

        for i in range(num_bins):
            bin_start = i * bin_size
            bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
            bin_timesteps = timesteps_np[bin_start:bin_end]

            if diffusion_skip_timestep[i] == 0:
                # Bin marked with 0: keep all timesteps.
                new_timesteps.extend(bin_timesteps)
            elif diffusion_skip_timestep[i] == 1:
                # Bin marked with 1: drop the first timestep of the bin.
                new_timesteps.extend(bin_timesteps[1:])
            else:
                # Any other marker n: randomly drop n timesteps from the bin
                # (keep the bin unchanged if it has at most n timesteps).
                if len(bin_timesteps) > diffusion_skip_timestep[i]:
                    indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
                    timesteps_to_keep = [
                        timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
                    ]
                else:
                    timesteps_to_keep = bin_timesteps
                new_timesteps.extend(timesteps_to_keep)

        new_timesteps_tensor = torch.tensor(new_timesteps, device=device)

    return new_timesteps_tensor
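
# Worked example (illustrative): with descending scheduler timesteps
#   skip_diffusion_timestep(torch.tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0]), [0, 1])
# the ten timesteps are split into two bins of five; the first bin (marker 0)
# is kept as-is and the second bin (marker 1) loses its first timestep, giving
#   tensor([9, 8, 7, 6, 5, 3, 2, 1, 0])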