File size: 2,988 Bytes
96d7ad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import List, Optional, Tuple, Union
import torch


from diffusers.utils.torch_utils import randn_tensor


def random_noise(
    tensor: torch.Tensor = None,
    shape: Tuple[int] = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    noise_offset: Optional[float] = None,  # typical value is 0.1
) -> torch.Tensor:
    """Sample Gaussian noise via ``randn_tensor``, optionally with offset noise.

    Args:
        tensor: If given, its shape, device and dtype override ``shape``,
            ``device`` and ``dtype``.
        shape: Shape of the noise to sample. The offset term (below) treats the
            first two axes as batch and channel; at least two axes are needed
            when ``noise_offset`` is set.
        dtype: dtype of the sampled noise.
        device: Target device; a string is converted with ``torch.device``.
        generator: A single generator or a list of generators, forwarded to
            ``randn_tensor``.
        noise_offset: If not ``None``, add ``noise_offset`` times an extra
            random value drawn per (batch, channel) and broadcast over all
            remaining axes ("offset noise").

    Returns:
        Noise tensor of the resolved ``shape``.
    """
    if tensor is not None:
        shape = tensor.shape
        device = tensor.device
        dtype = tensor.dtype
    if isinstance(device, str):
        device = torch.device(device)
    noise = randn_tensor(shape, dtype=dtype, device=device, generator=generator)
    if noise_offset is not None:
        # https://www.crosslabs.org//blog/diffusion-with-offset-noise
        # One value per (batch, channel), singleton on every other axis so it
        # broadcasts as a constant shift. Built from the resolved `shape`
        # (not `tensor`, which may be None) and sampled with the same
        # dtype/device/generator so the result stays reproducible.
        offset_shape = (shape[0], shape[1]) + (1,) * (len(shape) - 2)
        noise += noise_offset * randn_tensor(
            offset_shape, dtype=dtype, device=device, generator=generator
        )
    return noise


def video_fusion_noise(
    tensor: torch.Tensor = None,
    shape: Tuple[int] = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
    w_ind_noise: float = 0.5,
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    initial_common_noise: torch.Tensor = None,
) -> torch.Tensor:
    """Sample video noise that blends a frame-shared and a per-frame component.

    The result is ``sqrt(1 - w_ind_noise) * common + sqrt(w_ind_noise) * ind``.
    ``common`` has a singleton time axis ``(b, c, 1, h, w)`` and is broadcast
    over every frame. ``ind`` has the full ``(b, c, t, h, w)`` shape. When both
    components come from ``randn_tensor``, the two weights' squares sum to one.

    Args:
        tensor: If given, its shape, device and dtype override ``shape``,
            ``device`` and ``dtype``.
        shape: Output shape, unpacked as ``(batch_size, c, t, h, w)``.
        dtype: dtype of the sampled noise.
        device: Target device; a string is converted with ``torch.device``.
        w_ind_noise: Weight of the individual (per-frame) component, in [0, 1].
            0 yields the same noise in every frame; 1 yields only the
            per-frame component.
        generator: A single generator, or a list with one generator per batch
            element (each sample is then drawn with its own generator).
        initial_common_noise: Optional precomputed common component used instead
            of sampling one. With a generator list, a batch dimension equal to
            ``batch_size`` is split so each sample gets its own slice; any other
            batch size (e.g. 1) is passed unchanged to every sample.

    Returns:
        Noise tensor of shape ``(batch_size, c, t, h, w)``.

    Raises:
        ValueError: If ``w_ind_noise`` is outside [0, 1], or a generator list
            does not contain exactly ``batch_size`` generators.
    """
    if tensor is not None:
        shape = tensor.shape
        device = tensor.device
        dtype = tensor.dtype
    if isinstance(device, str):
        device = torch.device(device)
    # Outside [0, 1] one of the square roots below is NaN, which would silently
    # poison the latents instead of failing loudly.
    if not 0.0 <= w_ind_noise <= 1.0:
        raise ValueError(f"`w_ind_noise` must be in [0, 1], but got {w_ind_noise}.")
    batch_size, c, t, h, w = shape

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        split_common = (
            initial_common_noise is not None
            and initial_common_noise.shape[0] == batch_size
        )
        latents = []
        for i, sample_generator in enumerate(generator):
            sample_common = initial_common_noise
            if split_common:
                # Hand each sample only its own slice. Passing the whole batch
                # would broadcast every sample to (batch_size, ...), and the
                # concat below would then produce batch_size**2 rows.
                sample_common = initial_common_noise[i : i + 1]
            latents.append(
                video_fusion_noise(
                    shape=(1, c, t, h, w),
                    dtype=dtype,
                    device=device,
                    w_ind_noise=w_ind_noise,
                    generator=sample_generator,
                    initial_common_noise=sample_common,
                )
            )
        # Every per-sample result is already on `device`.
        return torch.cat(latents, dim=0)

    if initial_common_noise is not None:
        common_noise = initial_common_noise.to(device, dtype=dtype)
    else:
        # Singleton time axis: one noise map shared by all frames.
        common_noise = randn_tensor(
            (batch_size, c, 1, h, w),
            generator=generator,
            device=device,
            dtype=dtype,
        )
    ind_noise = randn_tensor(
        shape,
        generator=generator,
        device=device,
        dtype=dtype,
    )  # individual (per-frame) noise
    s = torch.tensor(w_ind_noise, device=device, dtype=dtype)
    return torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise