hugo flores garcia
recovering from a gittastrophe
41b9d24
raw
history blame
6.53 kB
from typing import Optional
import torch
from audiotools import AudioSignal
from .util import scalar_to_batch_tensor
def _gamma(r):
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
def _invgamma(y):
if not torch.is_tensor(y):
y = torch.tensor(y)[None]
return 2 * y.acos() / torch.pi
def full_mask(x: torch.Tensor):
    """Return a long mask of ones shaped like ``x`` (every position masked).

    ``x`` must be 3-D: (batch, n_codebooks, seq).
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.ones_like(x, dtype=torch.long)
def empty_mask(x: torch.Tensor):
    """Return a long mask of zeros shaped like ``x`` (nothing masked).

    ``x`` must be 3-D: (batch, n_codebooks, seq).
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.zeros_like(x, dtype=torch.long)
def apply_mask(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_token: int
):
    """Replace the masked positions of ``x`` with ``mask_token``.

    Args:
        x: token tensor with the same shape as ``mask``.
        mask: 3-D long tensor of 0/1 values; 1 marks a position to replace.
        mask_token: value written into every masked position.

    Returns:
        ``(masked_x, mask)``: a new tensor computed as
        ``x * (1 - mask) + mask_token * mask``, and the mask that was passed in.

    Raises:
        AssertionError: if ``mask`` is not 3-D, does not match ``x``'s shape,
            is not long dtype, or contains values other than 0 and 1.
    """
    # f-prefixes were missing on the first and third messages, so the actual
    # ndim / dtype was never shown.
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert not torch.any(mask > 1), "mask must be binary"
    assert not torch.any(mask < 0), "mask must be binary"

    fill_x = torch.full_like(x, mask_token)
    # arithmetic blend (rather than torch.where) keeps the original dtype
    # promotion of x with the long mask
    x = x * (1 - mask) + fill_x * mask
    return x, mask
def random(
    x: torch.Tensor,
    r: torch.Tensor
):
    """Sample a random binary mask shaped like ``x``.

    Each position is masked (1) independently with probability ``_gamma(r)``
    for its batch item, i.e. ``cos(r * pi / 2)`` clamped to ``[1e-10, 1]``.

    Args:
        x: 3-D tensor (batch, n_codebooks, seq) that defines the shape.
        r: per-item ratio tensor indexed along the batch dimension; a
            non-tensor value is converted with ``scalar_to_batch_tensor``
            (presumably one value per batch item) and moved to ``x.device``.

    Returns:
        long tensor of 0/1 values shaped like ``x``.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)

    # one probability per batch item, broadcast over codebooks and time
    mask_prob = _gamma(r)[:, None, None]
    sampled = torch.bernoulli(mask_prob * torch.ones_like(x))
    return sampled.round().long()
def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    """Sample a random binary mask whose masking probability is ``r`` itself.

    Unlike ``random``, no cosine schedule is applied: each position of batch
    item ``i`` is masked (1) independently with probability ``r[i]``.

    Args:
        x: 3-D tensor (batch, n_codebooks, seq) that defines the shape.
        r: per-item probability tensor indexed along the batch dimension; a
            non-tensor value is converted with ``scalar_to_batch_tensor``
            (presumably one value per batch item), moved to ``x.device`` and
            cast to float.

    Returns:
        long tensor of 0/1 values shaped like ``x``.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()

    # ones_like already lives on x.device with x's full shape, so the former
    # extra .to(x.device) and expand(...) were no-ops; broadcasting r over
    # (codebook, seq) does the rest.
    probs = torch.ones_like(x, dtype=torch.float) * r[:, None, None]
    mask = torch.bernoulli(probs)
    return mask.round().long()
def inpaint(x: torch.Tensor,
            n_prefix,
            n_suffix,
            ):
    """Build a mask that leaves a prefix and/or suffix of each item unmasked.

    Starts from ``full_mask(x)`` (all ones, everything masked). It then zeroes
    the first ``n_prefix`` and the last ``n_suffix`` timesteps of each batch
    item, across all codebooks. Non-positive counts leave that side masked.

    Args:
        x: tensor whose shape defines the mask (asserted 3-D by ``full_mask``).
        n_prefix: leading steps to keep. Either a python scalar for every
            batch item (converted with ``scalar_to_batch_tensor``, presumably
            one value per item) or a tensor with one value per batch item.
            A 0-d tensor is broadcast to the whole batch.
        n_suffix: same, for trailing steps.

    Returns:
        long mask shaped like ``x``; 0 = keep, 1 = mask.
    """
    assert n_prefix is not None
    assert n_suffix is not None

    mask = full_mask(x)

    def _per_item(n):
        # Turn a scalar / 0-d tensor into one value per batch item.
        if not isinstance(n, torch.Tensor):
            n = scalar_to_batch_tensor(n, x.shape[0]).to(x.device)
        elif n.ndim == 0:
            n = n.expand(x.shape[0])
        return n

    # `n > 0` is only evaluated on python scalars: on a tensor with more than
    # one element it has no single truth value and raises. Per-item zeros are
    # skipped inside the loops instead.
    if isinstance(n_prefix, torch.Tensor) or n_prefix > 0:
        for i, n in enumerate(_per_item(n_prefix)):
            if n > 0:
                mask[i, :, :n] = 0
    if isinstance(n_suffix, torch.Tensor) or n_suffix > 0:
        for i, n in enumerate(_per_item(n_suffix)):
            if n > 0:
                # negative start counts back from the end of the sequence
                mask[i, :, -n:] = 0
    return mask
def periodic_mask(x: torch.Tensor,
                  period: int, width: int = 1,
                  random_roll=False,
                  ):
    """Build a mask that leaves every ``period``-th timestep unmasked.

    Starts from ``full_mask(x)`` (all ones, everything masked). For each
    batch item, every time index that is a multiple of its period becomes an
    anchor. A window around each anchor is zeroed (unmasked) across all
    codebooks. The window is ``[j - width // 2, j + width // 2]``, clipped to
    the sequence.

    Args:
        x: tensor whose shape defines the mask (asserted 3-D by ``full_mask``).
        period: an int applied to every batch item, or a tensor with one
            period per item (a 0-d tensor is broadcast to the batch). A period
            of 0 leaves that item fully masked.
        width: nominal window width around each anchor.
        random_roll: if True, roll the whole mask along time by an offset drawn
            from ``torch.randint(0, period[0])``. The same offset is used for
            every batch item. The roll is skipped when ``period[0] <= 0``
            (``randint`` would raise).

    Returns:
        long mask shaped like ``x``; 0 = keep, 1 = mask.
    """
    mask = full_mask(x)

    # Compare against 0 only for python scalars: `tensor == 0` on a batch
    # tensor has no single truth value and would raise.
    if not isinstance(period, torch.Tensor):
        if period == 0:
            return mask
        period = scalar_to_batch_tensor(period, x.shape[0])
    elif period.ndim == 0:
        period = period.expand(x.shape[0])

    seq_len = mask.shape[-1]
    half = width // 2
    for i, factor in enumerate(period):
        # abs(): j % factor == 0 selects the same indices for +/- factor
        step = abs(int(factor))
        if step == 0:
            continue
        # Visit only the anchors instead of testing j % factor for every j.
        # The former bernoulli "coin flip" drew from ones (asserted to be all
        # ones), so the fill was always 0; it only wasted RNG draws and
        # per-window allocations.
        for j in range(0, seq_len, step):
            j_start = max(0, j - half)
            j_end = min(seq_len - 1, j + half) + 1
            mask[i, :, j_start:j_end] = 0

    if random_roll and int(period[0]) > 0:
        # add a random offset to the mask
        offset = torch.randint(0, int(period[0]), (1,))
        mask = torch.roll(mask, offset.item(), dims=-1)
    return mask
def codebook_unmask(
    mask: torch.Tensor,
    n_conditioning_codebooks: Optional[int]
):
    """Unmask (set to 0) the first ``n_conditioning_codebooks`` codebooks.

    Args:
        mask: 3-D mask indexed as (batch, codebook, seq).
        n_conditioning_codebooks: number of leading codebooks to unmask, or
            None to leave the mask untouched.

    Returns:
        a modified copy of ``mask``, or ``mask`` itself (not copied) when
        ``n_conditioning_codebooks`` is None. The input is never mutated.
    """
    if n_conditioning_codebooks is None:
        return mask
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask
def codebook_mask(mask: torch.Tensor, val1: int, val2: Optional[int] = None):
    """Fully mask (set to 1) every codebook from index ``val1`` onward.

    Args:
        mask: 3-D mask indexed as (batch, codebook, seq); not mutated.
        val1: first codebook index to mask.
        val2: accepted for backward compatibility but currently ignored.

    Returns:
        a modified copy of ``mask``.
    """
    mask = mask.clone()
    mask[:, val1:, :] = 1
    return mask
def mask_and(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Elementwise minimum of two same-shaped masks.

    For binary masks this is a logical AND: a position is masked (1) only
    where both inputs are masked.
    """
    assert mask1.shape == mask2.shape, "masks must be same shape"
    return torch.minimum(mask1, mask2)
def dropout(
    mask: torch.Tensor,
    p: float,
):
    """Randomly mask additional positions of a binary mask.

    Positions already masked (1) stay masked. Each unmasked (0) position is
    independently kept unmasked with probability ``1 - p``, otherwise it
    becomes masked.

    Returns:
        long tensor of 0/1 values shaped like ``mask``.
    """
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"

    # probability that each position ends up (still) unmasked
    keep_prob = (1 - p) * (~mask.bool()).float()
    kept = torch.bernoulli(keep_prob)
    return (~kept.round().bool()).long()
def mask_or(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Logical OR of two binary masks, computed as a sum clamped to [0, 1].

    Both masks must have the same shape and contain only values in [0, 1].
    """
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    named = (("mask1", mask1), ("mask2", mask2))
    # upper bounds for both masks are checked before the lower bounds
    for label, m in named:
        assert m.max() <= 1, f"{label} must be binary"
    for label, m in named:
        assert m.min() >= 0, f"{label} must be binary"

    combined = mask1 + mask2
    return torch.clamp(combined, min=0, max=1)
def time_stretch_mask(
    x: torch.Tensor,
    stretch_factor: int,
):
    """Return ``periodic_mask(x, stretch_factor, width=1)``.

    Every ``stretch_factor``-th timestep is left unmasked (0) and everything
    else is masked (1).

    Args:
        x: tensor whose shape defines the mask (3-D, per ``periodic_mask``).
        stretch_factor: period of the unmasked steps; must be >= 1.

    Returns:
        long mask shaped like ``x``.
    """
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    # periodic_mask only reads x's shape. The former repeat_interleave of x
    # followed by trimming back to the original length produced a tensor of
    # exactly that same shape, so it only cost a seq * stretch_factor copy.
    return periodic_mask(x, stretch_factor, width=1)
def onset_mask(
    sig: AudioSignal,
    z: torch.Tensor,
    interface,
    width: int = 1,
):
    """Unmask the frames around detected audio onsets.

    Onsets are detected with ``librosa.onset.onset_detect`` (``backtrack=True``)
    on ``sig.samples[0][0]`` only (presumably the first channel of the first
    batch item). The codec hop length is used, so the returned indices are
    frame indices at that hop. The mask starts as ones shaped like ``z``. It
    is zeroed over ``[idx - width, idx + width)`` (clipped at 0) for every
    onset, across all batch items and codebooks.

    Args:
        sig: input audio.
        z: tensor whose shape/dtype the mask copies; presumably the codec
            tokens (batch, n_codebooks, seq) — TODO confirm.
        interface: object exposing ``codec.hop_length``.
        width: frames unmasked before each onset (and up to ``width - 1`` after).

    Returns:
        tensor like ``z`` with 1 = masked, 0 = unmasked.
    """
    import librosa

    onset_frame_idxs = librosa.onset.onset_detect(
        y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
        hop_length=interface.codec.hop_length,
        backtrack=True,
    )
    if len(onset_frame_idxs) == 0:
        print("no onsets detected")

    mask = torch.ones_like(z)
    for idx in onset_frame_idxs:
        idx = int(idx)
        # Clip the window start at 0: a negative start is read as an offset
        # from the end, which made the slice empty for onsets within the
        # first `width` frames.
        start = max(0, idx - width)
        mask[:, :, start:idx + width] = 0
    return mask
if __name__ == "__main__":
    # Script entry point: loads an example audio file.
    # NOTE(review): `sig` is not used in the visible lines; the demo may
    # continue past this view or be unfinished — confirm.
    sig = AudioSignal("assets/example.wav")