# TODO: Adapted from cli
import math
from typing import Callable, List, Optional

import numpy as np

from mmcm.utils.itertools_util import generate_sample_idxs

# Copied from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py


def ordered_halving(val):
    bin_str = f"{val:064b}"
    bin_flip = bin_str[::-1]
    as_int = int(bin_flip, 2)

    return as_int / (1 << 64)
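
# Note: ordered_halving bit-reverses the 64-bit binary representation of
# `val`, mapping 0, 1, 2, 3, 4, ... to 0.0, 0.5, 0.25, 0.75, 0.125, ...
# (the base-2 van der Corput low-discrepancy sequence), so each denoising
# step gets a well-spread fractional offset for its context windows.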


# TODO: closed_loop does not work yet; fix it
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    # a single window covers the whole clip (also guards context_size=None)
    if context_size is None or num_frames <= context_size:
        yield list(range(num_frames))
        return

    # Cap the number of dyadic stride levels (1, 2, 4, ...) so that the
    # largest stride is just big enough for one window to span all frames.
    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    for context_step in 1 << np.arange(context_stride):
        # Shift the window starts by a step-dependent fraction of the clip
        # length so that successive denoising steps see different windows.
        pad = int(round(num_frames * ordered_halving(step)))
        for j in range(
            int(ordered_halving(step) * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            yield [
                e % num_frames
                for e in range(j, j + context_size * context_step, context_step)
            ]
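
# Illustrative example (hand-computed, not from the original repo): at step 0
# with num_frames=24, context_size=12, context_overlap=4, `uniform` yields
#   [0..11], [8..19], [16..23, 0..3]            (stride 1)
#   [0, 2, ..., 22], [20, 22, 0, 2, ..., 18]    (stride 2)
# i.e. overlapping windows at each power-of-two stride, wrapped modulo
# num_frames by the `e % num_frames` above.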


def uniform_v2(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    return generate_sample_idxs(
        total=num_frames,
        window_size=context_size,
        step=context_size - context_overlap,
        sample_rate=1,
        drop_last=False,
    )
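
# Note: unlike `uniform`, `uniform_v2` ignores `step`, `context_stride`, and
# `closed_loop`; it delegates to generate_sample_idxs, so every denoising
# step reuses the same sequential overlapping windows.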


def get_context_scheduler(name: str) -> Callable:
    if name == "uniform":
        return uniform
    elif name == "uniform_v2":
        return uniform_v2
    else:
        raise ValueError(f"Unknown context schedule policy {name}")


def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    return sum(
        len(
            list(
                scheduler(
                    i,
                    num_steps,
                    num_frames,
                    context_size,
                    context_stride,
                    context_overlap,
                    closed_loop,
                )
            )
        )
        for i in range(len(timesteps))
    )
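
# Note: get_total_steps counts every context window the scheduler will emit
# across all denoising timesteps, e.g. to size a progress bar before the
# sampling loop starts.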


def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]:
    """if len(contexts)>=2 and the max value the oenultimate list same as  of the last list

    Args:
        List (_type_): _description_

    Returns:
        List[List[int]]: _description_
    """
    if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]:
        return contexts[:-1]
    else:
        return contexts
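
# Illustrative example: for contexts [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5]],
# the last two windows both end at index 5, so the trailing [4, 5] window is
# dropped and [[0, 1, 2, 3], [2, 3, 4, 5]] is returned.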


def prepare_global_context(
    context_schedule: str,
    num_inference_steps: int,
    time_size: int,
    context_frames: int,
    context_stride: int,
    context_overlap: int,
    context_batch_size: int,
):
    context_scheduler = get_context_scheduler(context_schedule)
    context_queue = list(
        context_scheduler(
            step=0,
            num_steps=num_inference_steps,
            num_frames=time_size,
            context_size=context_frames,
            context_stride=context_stride,
            context_overlap=context_overlap,
        )
    )
    # If the max index of the last context window equals that of the
    # second-to-last, the last window is a redundancy introduced by the step
    # size and can be dropped.
    context_queue = drop_last_repeat_context(context_queue)
    num_context_batches = math.ceil(len(context_queue) / context_batch_size)
    global_context = [
        context_queue[i * context_batch_size : (i + 1) * context_batch_size]
        for i in range(num_context_batches)
    ]
    return global_context
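

if __name__ == "__main__":
    # Minimal usage sketch with illustrative values (not from the original
    # repo): batch context windows for a 24-frame clip into groups of two,
    # using 12-frame windows with a 4-frame overlap.
    global_context = prepare_global_context(
        context_schedule="uniform",
        num_inference_steps=30,
        time_size=24,
        context_frames=12,
        context_stride=3,
        context_overlap=4,
        context_batch_size=2,
    )
    for i, batch in enumerate(global_context):
        print(f"batch {i}: {batch}")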