from typing import Optional
import torch
from audiotools import AudioSignal
from .util import scalar_to_batch_tensor

def _gamma(r):
    return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)

def _invgamma(y):
    if not torch.is_tensor(y):
        y = torch.tensor(y)[None]
    return 2 * y.acos() / torch.pi
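
# Illustrative values for the cosine schedule (approximate, for reference only):
#   _gamma(torch.tensor([0.0])) -> ~1.0    (masking probability starts at 1)
#   _gamma(torch.tensor([0.5])) -> ~0.707
#   _gamma(torch.tensor([1.0])) -> ~0.0    (clamped below at 1e-10)
# _invgamma is its inverse on [0, 1]: _invgamma(_gamma(r)) ~= r.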

def full_mask(x: torch.Tensor):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.ones_like(x).long()

def empty_mask(x: torch.Tensor):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.zeros_like(x).long()

def apply_mask(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_token: int
):
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim} dims"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert not torch.any(mask > 1), "mask must be binary"
    assert not torch.any(mask < 0), "mask must be binary"

    # replace masked positions (mask == 1) with the mask token
    fill_x = torch.full_like(x, mask_token)
    x = x * (1 - mask) + fill_x * mask

    return x, mask
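
# For example, with a hypothetical mask_token of 1024:
#   x    = [[[5,    9,    3]]]
#   mask = [[[0,    1,    1]]]
#   apply_mask(x, mask, 1024) -> ([[[5, 1024, 1024]]], mask)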

def random(
    x: torch.Tensor,
    r: torch.Tensor
):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)

    # mask each position independently with probability _gamma(r)
    r = _gamma(r)[:, None, None]
    probs = torch.ones_like(x) * r

    mask = torch.bernoulli(probs)
    mask = mask.round().long()

    return mask

def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
    r = r[:, None, None]

    # mask each position independently with probability r (no schedule applied)
    probs = torch.ones_like(x).float() * r

    mask = torch.bernoulli(probs)
    mask = mask.round().long()

    return mask

def inpaint(
    x: torch.Tensor,
    n_prefix,
    n_suffix,
):
    assert n_prefix is not None
    assert n_suffix is not None

    mask = full_mask(x)

    # if we have a prefix or suffix, unmask it so it is kept as conditioning
    if not isinstance(n_prefix, torch.Tensor):
        n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
    for i, n in enumerate(n_prefix):
        if n > 0:
            mask[i, :, :n] = 0.0

    if not isinstance(n_suffix, torch.Tensor):
        n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
    for i, n in enumerate(n_suffix):
        if n > 0:
            mask[i, :, -n:] = 0.0

    return mask
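
# For example (illustrative shapes): with seq length 6, n_prefix=2 and
# n_suffix=1, each codebook row of the mask comes out as [0, 0, 1, 1, 1, 0];
# the prefix and suffix positions are kept as conditioning.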

def periodic_mask(
    x: torch.Tensor,
    period: int,
    width: int = 1,
    random_roll=False,
):
    mask = full_mask(x)
    if not isinstance(period, torch.Tensor):
        period = scalar_to_batch_tensor(period, x.shape[0])
    if torch.all(period == 0):
        return mask

    for i, factor in enumerate(period):
        if factor == 0:
            continue
        for j in range(mask.shape[-1]):
            if j % factor == 0:
                # unmask a window of `width` positions centered on j
                j_start = max(0, j - width // 2)
                j_end = min(mask.shape[-1] - 1, j + width // 2) + 1
                mask[i, :, j_start:j_end] = 0

    if random_roll:
        # add a random offset to the mask
        offset = torch.randint(0, int(period[0]), (1,))
        mask = torch.roll(mask, offset.item(), dims=-1)

    return mask
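
# For example (illustrative): period=4 with width=1 on a length-8 sequence
# unmasks timesteps 0 and 4, giving rows of [0, 1, 1, 1, 0, 1, 1, 1].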

def codebook_unmask(
    mask: torch.Tensor,
    n_conditioning_codebooks: int
):
    if n_conditioning_codebooks is None:
        return mask

    # if we have any conditioning codebooks, set their mask to 0
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask

def codebook_mask(mask: torch.Tensor, val1: int, val2: Optional[int] = None):
    mask = mask.clone()
    # mask every codebook level from val1 onward
    mask[:, val1:, :] = 1
    # val2 = val2 or val1
    # vs = torch.linspace(val1, val2, mask.shape[1])
    # for t, v in enumerate(vs):
    #     v = int(v)
    #     mask[:, v:, t] = 1
    return mask

def mask_and(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    assert mask1.shape == mask2.shape, "masks must be same shape"
    return torch.min(mask1, mask2)

def dropout(
    mask: torch.Tensor,
    p: float,
):
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"

    # invert so unmasked positions are 1, keep each with probability (1 - p),
    # then invert back: each unmasked position becomes masked with probability
    # p, while already-masked positions stay masked
    mask = (~mask.bool()).float()
    mask = torch.bernoulli(mask * (1 - p))
    mask = ~mask.round().bool()
    return mask.long()

def mask_or(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    assert mask1.max() <= 1, "mask1 must be binary"
    assert mask2.max() <= 1, "mask2 must be binary"
    assert mask1.min() >= 0, "mask1 must be binary"
    assert mask2.min() >= 0, "mask2 must be binary"
    return (mask1 + mask2).clamp(0, 1)

def time_stretch_mask(
    x: torch.Tensor,
    stretch_factor: int,
):
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    c_seq_len = x.shape[-1]
    x = x.repeat_interleave(stretch_factor, dim=-1)

    # trim back to the original sequence length
    x = x[:, :, :c_seq_len]

    mask = periodic_mask(x, stretch_factor, width=1)
    return mask

def onset_mask(
    sig: AudioSignal,
    z: torch.Tensor,
    interface,
    width: int = 1,
):
    import librosa

    onset_frame_idxs = librosa.onset.onset_detect(
        y=sig.samples[0][0].detach().cpu().numpy(),
        sr=sig.sample_rate,
        hop_length=interface.codec.hop_length,
        backtrack=True,
    )
    if len(onset_frame_idxs) == 0:
        print("no onsets detected")
    print("onset_frame_idxs", onset_frame_idxs)
    print("mask shape", z.shape)

    mask = torch.ones_like(z)
    for idx in onset_frame_idxs:
        # clamp the window start so an onset near 0 can't wrap around
        # to the end of the sequence via negative indexing
        j_start = max(0, idx - width)
        mask[:, :, j_start:idx + width] = 0

    return mask

if __name__ == "__main__":
    sig = AudioSignal("assets/example.wav")
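
    # A minimal usage sketch added for illustration -- not part of the original
    # script. Shapes follow the (batch, n_codebooks, seq) convention above;
    # the vocab size / mask token id of 1024 is an assumption.
    z = torch.randint(0, 1024, (1, 4, 100))

    # mask each position with probability _gamma(0.5) ~= 0.707
    mask = random(z, 0.5)

    # keep a 10-step prefix and suffix as conditioning
    mask = mask_and(mask, inpaint(z, n_prefix=10, n_suffix=10))

    # also unmask every 7th timestep
    mask = mask_and(mask, periodic_mask(z, period=7, width=1))

    # leave the first codebook level fully unmasked
    mask = codebook_unmask(mask, n_conditioning_codebooks=1)

    z_masked, mask = apply_mask(z, mask, mask_token=1024)
    print(f"masked {mask.sum().item()} of {mask.numel()} positions")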