Spaces:
Configuration error
Configuration error
Upload src/pipelines/context.py with huggingface_hub
Browse files- src/pipelines/context.py +76 -0
src/pipelines/context.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: Adapted from cli
|
2 |
+
from typing import Callable, List, Optional
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def ordered_halving(val):
|
8 |
+
bin_str = f"{val:064b}"
|
9 |
+
bin_flip = bin_str[::-1]
|
10 |
+
as_int = int(bin_flip, 2)
|
11 |
+
|
12 |
+
return as_int / (1 << 64)
|
13 |
+
|
14 |
+
|
15 |
+
def uniform(
|
16 |
+
step: int = ...,
|
17 |
+
num_steps: Optional[int] = None,
|
18 |
+
num_frames: int = ...,
|
19 |
+
context_size: Optional[int] = None,
|
20 |
+
context_stride: int = 3,
|
21 |
+
context_overlap: int = 4,
|
22 |
+
closed_loop: bool = True,
|
23 |
+
):
|
24 |
+
if num_frames <= context_size:
|
25 |
+
yield list(range(num_frames))
|
26 |
+
return
|
27 |
+
|
28 |
+
context_stride = min(
|
29 |
+
context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
|
30 |
+
)
|
31 |
+
|
32 |
+
for context_step in 1 << np.arange(context_stride):
|
33 |
+
pad = int(round(num_frames * ordered_halving(step)))
|
34 |
+
for j in range(
|
35 |
+
int(ordered_halving(step) * context_step) + pad,
|
36 |
+
num_frames + pad + (0 if closed_loop else -context_overlap),
|
37 |
+
(context_size * context_step - context_overlap),
|
38 |
+
):
|
39 |
+
yield [
|
40 |
+
e % num_frames
|
41 |
+
for e in range(j, j + context_size * context_step, context_step)
|
42 |
+
]
|
43 |
+
|
44 |
+
|
45 |
+
def get_context_scheduler(name: str) -> Callable:
|
46 |
+
if name == "uniform":
|
47 |
+
return uniform
|
48 |
+
else:
|
49 |
+
raise ValueError(f"Unknown context_overlap policy {name}")
|
50 |
+
|
51 |
+
|
52 |
+
def get_total_steps(
|
53 |
+
scheduler,
|
54 |
+
timesteps: List[int],
|
55 |
+
num_steps: Optional[int] = None,
|
56 |
+
num_frames: int = ...,
|
57 |
+
context_size: Optional[int] = None,
|
58 |
+
context_stride: int = 3,
|
59 |
+
context_overlap: int = 4,
|
60 |
+
closed_loop: bool = True,
|
61 |
+
):
|
62 |
+
return sum(
|
63 |
+
len(
|
64 |
+
list(
|
65 |
+
scheduler(
|
66 |
+
i,
|
67 |
+
num_steps,
|
68 |
+
num_frames,
|
69 |
+
context_size,
|
70 |
+
context_stride,
|
71 |
+
context_overlap,
|
72 |
+
)
|
73 |
+
)
|
74 |
+
)
|
75 |
+
for i in range(len(timesteps))
|
76 |
+
)
|