from typing import Optional

import torch
from audiotools import AudioSignal

from .util import scalar_to_batch_tensor
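
# Masks in this module are long tensors shaped (batch, n_codebooks, seq):
# a 1 marks a position that will be replaced with the mask token, a 0 marks
# a position that is kept as-is.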

def _gamma(r):
    # cosine schedule: maps r in [0, 1] to a masking probability,
    # from 1.0 at r=0 down to ~0 at r=1
    return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)

def _invgamma(y):
    # inverse of the cosine schedule above
    if not torch.is_tensor(y):
        y = torch.tensor(y)[None]
    return 2 * y.acos() / torch.pi

def full_mask(x: torch.Tensor):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.ones_like(x).long()

def empty_mask(x: torch.Tensor):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.zeros_like(x).long()

def apply_mask(
        x: torch.Tensor, 
        mask: torch.Tensor, 
        mask_token: int
    ):
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert not torch.any(mask > 1), "mask must be binary"
    assert not torch.any(mask < 0), "mask must be binary"

    fill_x = torch.full_like(x, mask_token)
    x = x * (1 - mask) + fill_x * mask

    return x, mask
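
# Example usage (illustrative sketch; the codebook size and mask token value
# below are placeholders, not values defined in this module):
#
#   codes = torch.randint(0, 1024, (1, 4, 100))    # (batch, n_codebooks, seq)
#   mask = random(codes, r=0.8)                    # cosine-scheduled random mask
#   masked_codes, mask = apply_mask(codes, mask, mask_token=1024)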

def random(
    x: torch.Tensor,
    r: torch.Tensor
):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)

    r = _gamma(r)[:, None, None]  # per-item masking probability from the cosine schedule
    probs = torch.ones_like(x) * r

    mask = torch.bernoulli(probs)
    mask = mask.round().long()

    return mask

def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
    # broadcast r across the codebook and time dimensions
    r = r[:, None, None]

    probs = torch.ones_like(x).to(x.device).float()
    probs = probs * r

    mask = torch.bernoulli(probs)
    mask = mask.round().long()

    return mask

def inpaint(x: torch.Tensor, 
    n_prefix,
    n_suffix,
):
    assert n_prefix is not None
    assert n_suffix is not None

    mask = full_mask(x)

    # leave the prefix and suffix unmasked (mask prob 0) so they are kept as context
    if not isinstance(n_prefix, torch.Tensor):
        n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
    if not isinstance(n_suffix, torch.Tensor):
        n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)

    for i, n in enumerate(n_prefix):
        if n > 0:
            mask[i, :, :n] = 0
    for i, n in enumerate(n_suffix):
        if n > 0:
            mask[i, :, -n:] = 0

    return mask
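
# e.g. inpaint(codes, n_prefix=25, n_suffix=25) keeps the first and last 25
# timesteps as context and masks everything in between.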

def periodic_mask(x: torch.Tensor, 
                period: int, width: int = 1, 
                random_roll=False,
    ):
    mask = full_mask(x)
    if period == 0:
        return mask

    if not isinstance(period, torch.Tensor):
        period = scalar_to_batch_tensor(period, x.shape[0])
    for i, factor in enumerate(period):
        if factor == 0:
            continue
        for j in range(mask.shape[-1]):
            if j % factor == 0:
                # unmask a window of `width` timesteps centered on each periodic position
                j_start = max(0, j - width // 2)
                j_end = min(mask.shape[-1] - 1, j + width // 2) + 1
                mask[i, :, j_start:j_end] = 0
    if random_roll:
        # add a random offset to the mask
        offset = torch.randint(0, int(period[0]), (1,))
        mask = torch.roll(mask, offset.item(), dims=-1)

    return mask
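
# e.g. periodic_mask(codes, period=7, width=1) keeps one timestep out of every
# 7 and masks the rest; random_roll=True shifts the kept positions by a random offset.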

def codebook_unmask(
    mask: torch.Tensor, 
    n_conditioning_codebooks: int
):
    if n_conditioning_codebooks is None:
        return mask
    # if we have any conditioning codebooks, set their mask to 0 (keep them as-is)
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask

def codebook_mask(mask: torch.Tensor, val1: int, val2: Optional[int] = None):
    mask = mask.clone()
    mask[:, val1:, :] = 1
    # val2 = val2 or val1
    # vs = torch.linspace(val1, val2, mask.shape[1])
    # for t, v in enumerate(vs):
    #     v = int(v)
    #     mask[:, v:, t] = 1 

    return mask

def mask_and(
    mask1: torch.Tensor, 
    mask2: torch.Tensor
):
    assert mask1.shape == mask2.shape, "masks must be same shape"
    return torch.min(mask1, mask2)

def dropout(
    mask: torch.Tensor,
    p: float,
):
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"
    mask = (~mask.bool()).float()
    mask = torch.bernoulli(mask * (1 - p))
    mask = ~mask.round().bool()
    return mask.long()

def mask_or(
    mask1: torch.Tensor, 
    mask2: torch.Tensor
):
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    assert mask1.max() <= 1, "mask1 must be binary"
    assert mask2.max() <= 1, "mask2 must be binary"
    assert mask1.min() >= 0, "mask1 must be binary"
    assert mask2.min() >= 0, "mask2 must be binary"
    return (mask1 + mask2).clamp(0, 1)
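
# Masks compose elementwise, e.g. (a hypothetical combination):
#   mask = mask_and(periodic_mask(codes, 7), inpaint(codes, 25, 25))
# masks a position only where both input masks do.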

def time_stretch_mask(
    x: torch.Tensor, 
    stretch_factor: int,
):
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    c_seq_len = x.shape[-1]
    x = x.repeat_interleave(stretch_factor, dim=-1)

    # trim the stretched codes back to the original length
    x = x[:, :, :c_seq_len]

    mask = periodic_mask(x, stretch_factor, width=1)
    return mask

def onset_mask(
    sig: AudioSignal, 
    z: torch.Tensor,
    interface,
    width: int = 1, 
):
    import librosa

    onset_frame_idxs = librosa.onset.onset_detect(
        y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate, 
        hop_length=interface.codec.hop_length,
        backtrack=True,
    )
    if len(onset_frame_idxs) == 0:
        print("no onsets detected")
    print("onset_frame_idxs", onset_frame_idxs)
    print("mask shape", z.shape)

    mask = torch.ones_like(z)
    for idx in onset_frame_idxs:
        # unmask a window of +/- width frames around each onset, clamped at 0
        start = max(0, idx - width)
        mask[:, :, start:idx + width] = 0

    return mask



if __name__ == "__main__":
    sig = AudioSignal("assets/example.wav")