Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
from typing import Callable, Optional, List | |
def ordered_halving(val):
    """Map an integer into [0, 1) by mirroring its 64-bit binary form.

    The zero-padded 64-bit binary string of ``val`` is reversed and read back
    as an integer, then scaled by 2**64.  Consecutive inputs therefore spread
    out by repeated halving: 0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, ...
    """
    mirrored_bits = format(val, "064b")[::-1]
    return int(mirrored_bits, 2) / 2**64
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Detect whether a window wraps from the last frame back toward frame 0.

    Each index is reduced modulo ``num_frames``; the window "rolls over" at the
    first position whose reduced value is smaller than the previous one.

    Returns:
        ``(True, position)`` for the first such position, else ``(False, -1)``.
    """
    wrapped_values = (frame % num_frames for frame in window)
    previous = -1
    for position, wrapped in enumerate(wrapped_values):
        if wrapped < previous:
            return True, position
        previous = wrapped
    return False, -1
def shift_window_to_start(window: list[int], num_frames: int):
    """Rebase ``window`` in place so that it begins at frame 0.

    Every index is replaced by its wrapped offset from the first element,
    ``((idx - window[0]) + num_frames) % num_frames``.  Returns None.
    """
    origin = window[0]
    # Adding num_frames before the modulus keeps the original formula verbatim.
    window[:] = [((idx - origin) + num_frames) % num_frames for idx in window]
def shift_window_to_end(window: list[int], num_frames: int):
    """Slide ``window`` in place so its last index becomes ``num_frames - 1``.

    The window is first rebased to start at 0 with ``shift_window_to_start``;
    then every index is offset by the same amount so the final element lands on
    the last frame.  Returns None.
    """
    shift_window_to_start(window, num_frames)
    offset = num_frames - window[-1] - 1
    window[:] = [idx + offset for idx in window]
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return the frame indexes in ``range(num_frames)`` covered by no window.

    Args:
        windows: context windows, each a list of frame indexes.
        num_frames: total number of frames.

    Returns:
        The uncovered indexes in ascending order.  Window values outside
        ``range(num_frames)`` (and repeated values) are simply ignored.
    """
    # Collect covered indexes once; a per-value list.remove() was
    # O(num_frames) each and relied on swallowing ValueError for misses.
    covered = {val for window in windows for val in window}
    return [idx for idx in range(num_frames) if idx not in covered]
def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield the context windows for sampling step ``step``; windows may wrap.

    For each dilation ``context_step`` in 1, 2, 4, ... (``context_stride``
    levels, capped at ``ceil(log2(num_frames / context_size)) + 1``), windows
    of ``context_size`` indexes spaced ``context_step`` apart are yielded.
    Start positions begin at an offset derived from ``ordered_halving(step)``
    and advance by ``context_size * context_step - context_overlap``.  Indexes
    are reduced modulo ``num_frames``, so a window can run past the last frame
    and continue from frame 0.  With ``closed_loop=False`` the range of start
    positions ends ``context_overlap`` frames earlier.

    If ``num_frames <= context_size``, one window covering every frame is
    yielded.  ``num_steps`` is unused.  ``context_size`` must be given despite
    its ``None`` default (comparing against None raises TypeError).

    NOTE(review): if ``context_overlap >= context_size`` the range step is
    <= 0 at the first level (0 makes ``range`` raise, negative yields nothing).
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    level_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, level_cap)

    for context_step in 1 << np.arange(context_stride):
        phase = ordered_halving(step)
        pad = int(round(num_frames * phase))
        first_start = int(phase * context_step) + pad
        stop = num_frames + pad - (0 if closed_loop else context_overlap)
        span = context_size * context_step
        for j in range(first_start, stop, span - context_overlap):
            yield [e % num_frames for e in range(j, j + span, context_step)]
# From AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return uniform context windows that never wrap past the last frame.

    Starts from exactly the windows ``uniform_looped`` yields for the same
    arguments, then post-processes them:

    * a window whose indexes wrap around (per ``does_window_roll_over``) is
      slid in place so it ends on the last frame (``shift_window_to_end``);
      if the next window (cyclically) does not contain the index the window
      wrapped to, a contiguous window of ``context_size`` frames starting at
      that index is inserted right after it;
    * windows identical to an earlier window are removed.

    If ``num_frames <= context_size``, a single window covering every frame is
    returned.

    Returns:
        A list of windows, each a list of frame indexes.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]
    # Window generation is identical to the looped scheduler; reuse it instead
    # of maintaining a second copy of that loop.
    windows = list(
        uniform_looped(
            step,
            num_steps,
            num_frames,
            context_size,
            context_stride,
            context_overlap,
            closed_loop,
        )
    )
    # now that windows are created, shift any windows that loop, and delete duplicate windows
    delete_idxs = []
    win_i = 0
    # windows may grow during iteration (insertions), so re-check len each pass
    while win_i < len(windows):
        # if window rolls over itself, need to shift it
        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
        if is_roll:
            roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(windows[win_i], num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i+1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(win_i+1, list(range(roll_val, roll_val + context_size)))
        # delete window if it's not unique
        for pre_i in range(0, win_i):
            if windows[win_i] == windows[pre_i]:
                delete_idxs.append(win_i)
                break
        win_i += 1
    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
    delete_idxs.reverse()
    for i in delete_idxs:
        windows.pop(i)
    return windows
def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Split the frames into the same fixed, contiguous windows on every call.

    Consecutive windows start ``context_size - context_overlap`` frames apart.
    The last window is pulled back so it ends exactly on the final frame while
    keeping the full ``context_size`` length.  If ``num_frames <= context_size``
    a single window covering every frame is returned.

    ``step``, ``num_steps``, ``context_stride`` and ``closed_loop`` are unused;
    they keep the signature aligned with the other schedulers.

    Returns:
        A list of windows, each a list of consecutive frame indexes.

    Raises:
        ValueError: if ``context_overlap >= context_size`` while
            ``num_frames > context_size`` -- windows would never advance
            (previously a cryptic ``range`` error, or silently no windows).
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]
    delta = context_size - context_overlap
    if delta <= 0:
        raise ValueError(
            f"context_overlap ({context_overlap}) must be smaller than "
            f"context_size ({context_size})"
        )
    windows = []
    for start_idx in range(0, num_frames, delta):
        ending = start_idx + context_size
        if ending >= num_frames:
            # Past the end: move the final window back so it keeps the same
            # context length and ends on the last frame.
            windows.append(list(range(num_frames - context_size, num_frames)))
            break
        windows.append(list(range(start_idx, ending)))
    return windows
def get_context_scheduler(name: str) -> Callable:
    """Look up a context scheduler function by name.

    Args:
        name: ``"uniform_looped"``, ``"uniform_standard"`` or ``"static_standard"``.

    Returns:
        The scheduler function of that name defined in this module.

    Raises:
        ValueError: if ``name`` is not one of the names above.
    """
    if name == "uniform_looped":
        return uniform_looped
    if name == "uniform_standard":
        return uniform_standard
    if name == "static_standard":
        return static_standard
    known = ("uniform_looped", "uniform_standard", "static_standard")
    # The old message called this a "context_overlap policy", which is not
    # what is being selected here.
    raise ValueError(f"Unknown context scheduler {name!r}; expected one of {', '.join(known)}")
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces over all timesteps.

    The scheduler is called once per timestep index ``i`` (``0`` to
    ``len(timesteps) - 1``) with the same window settings, and the number of
    windows it returns or yields is summed.

    Args:
        scheduler: a callable with the context-scheduler signature
            ``(step, num_steps, num_frames, context_size, context_stride,
            context_overlap, closed_loop)`` returning an iterable of windows.
        timesteps: only its length is used.
        num_steps, num_frames, context_size, context_stride, context_overlap,
            closed_loop: forwarded unchanged to ``scheduler``.

    Returns:
        The total window count across all timesteps.
    """
    return sum(
        len(
            list(
                scheduler(
                    i,
                    num_steps,
                    num_frames,
                    context_size,
                    context_stride,
                    context_overlap,
                    # Was previously dropped, so the scheduler always used its
                    # own default and ignored the caller's closed_loop.
                    closed_loop,
                )
            )
        )
        for i in range(len(timesteps))
    )