anchorxia's picture
add musev
96d7ad8
raw
history blame
4.29 kB
# TODO: Adapted from cli
import math
from typing import Callable, List, Optional
import numpy as np
from mmcm.utils.itertools_util import generate_sample_idxs
# copy from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py
def ordered_halving(val):
bin_str = f"{val:064b}"
bin_flip = bin_str[::-1]
as_int = int(bin_flip, 2)
return as_int / (1 << 64)
# TODO: closed_loop not work, to fix it
def uniform(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
if num_frames <= context_size:
yield list(range(num_frames))
return
context_stride = min(
context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
)
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * context_step) + pad,
num_frames + pad + (0 if closed_loop else -context_overlap),
(context_size * context_step - context_overlap),
):
yield [
e % num_frames
for e in range(j, j + context_size * context_step, context_step)
]
def uniform_v2(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
return generate_sample_idxs(
total=num_frames,
window_size=context_size,
step=context_size - context_overlap,
sample_rate=1,
drop_last=False,
)
def get_context_scheduler(name: str) -> Callable:
if name == "uniform":
return uniform
elif name == "uniform_v2":
return uniform_v2
else:
raise ValueError(f"Unknown context_overlap policy {name}")
def get_total_steps(
scheduler,
timesteps: List[int],
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
return sum(
len(
list(
scheduler(
i,
num_steps,
num_frames,
context_size,
context_stride,
context_overlap,
)
)
)
for i in range(len(timesteps))
)
def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]:
"""if len(contexts)>=2 and the max value the oenultimate list same as of the last list
Args:
List (_type_): _description_
Returns:
List[List[int]]: _description_
"""
if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]:
return contexts[:-1]
else:
return contexts
def prepare_global_context(
context_schedule: str,
num_inference_steps: int,
time_size: int,
context_frames: int,
context_stride: int,
context_overlap: int,
context_batch_size: int,
):
context_scheduler = get_context_scheduler(context_schedule)
context_queue = list(
context_scheduler(
step=0,
num_steps=num_inference_steps,
num_frames=time_size,
context_size=context_frames,
context_stride=context_stride,
context_overlap=context_overlap,
)
)
# 如果context_queue的最后一个索引最大值和倒数第二个索引最大值相同,说明最后一个列表就是因为step带来的冗余项,可以去掉
# remove the last context if max index of the last context is the same as the max index of the second last context
context_queue = drop_last_repeat_context(context_queue)
num_context_batches = math.ceil(len(context_queue) / context_batch_size)
global_context = []
for i_tmp in range(num_context_batches):
global_context.append(
context_queue[i_tmp * context_batch_size : (i_tmp + 1) * context_batch_size]
)
return global_context