from typing import Optional

import torch
from audiotools import AudioSignal

from .util import scalar_to_batch_tensor

# Mask convention used throughout this module (see `apply_mask`):
# 1 = position is masked (replaced by the mask token), 0 = position is kept.


def _gamma(r):
    """Cosine schedule: map ``r`` to ``cos(r * pi / 2)``, clamped to [1e-10, 1]."""
    return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)


def _invgamma(y):
    """Inverse of the cosine schedule: ``2 * acos(y) / pi``.

    Non-tensor input is wrapped in a 1-element tensor first.
    """
    if not torch.is_tensor(y):
        y = torch.tensor(y)[None]
    return 2 * y.acos() / torch.pi


def full_mask(x: torch.Tensor):
    """Return an int64 all-ones mask (everything masked) shaped like ``x``."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.ones_like(x).long()


def empty_mask(x: torch.Tensor):
    """Return an int64 all-zeros mask (nothing masked) shaped like ``x``."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.zeros_like(x).long()


def apply_mask(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_token: int,
):
    """Replace every position of ``x`` where ``mask == 1`` with ``mask_token``.

    Args:
        x: token tensor, same shape as ``mask``.
        mask: binary int64 mask of rank 3 (1 = replace, 0 = keep).
        mask_token: value written into masked positions.

    Returns:
        ``(masked_x, mask)``.
    """
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert not torch.any(mask > 1), "mask must be binary"
    assert not torch.any(mask < 0), "mask must be binary"

    fill_x = torch.full_like(x, mask_token)
    # Arithmetic select: keeps x where mask == 0, takes mask_token where mask == 1.
    x = x * (1 - mask) + fill_x * mask
    return x, mask


def random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    """Mask each position independently with probability ``_gamma(r)``.

    ``r`` is a scalar or a tensor with one value per batch item; scalars are
    broadcast to the batch via ``scalar_to_batch_tensor``.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)

    r = _gamma(r)[:, None, None]
    probs = torch.ones_like(x) * r

    mask = torch.bernoulli(probs)
    mask = mask.round().long()

    return mask


def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    """Mask each position independently with probability ``r`` (no schedule).

    ``r`` is a scalar or a tensor with one value per batch item.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()

    # ones_like already matches x's full (batch, n_codebooks, seq) shape and
    # device, so the per-batch probability just broadcasts over the other dims.
    probs = torch.ones_like(x, dtype=torch.float) * r[:, None, None]

    mask = torch.bernoulli(probs)
    mask = mask.round().long()

    return mask


def _as_batch_tensor(value, batch_size: int, device) -> torch.Tensor:
    """Return ``value`` as a tensor with one entry per batch item on ``device``.

    Python scalars go through ``scalar_to_batch_tensor`` (presumably producing
    a length-``batch_size`` 1-D tensor); 0-d tensors are broadcast to
    ``batch_size``; other tensors are only moved to ``device``.
    """
    if not isinstance(value, torch.Tensor):
        value = scalar_to_batch_tensor(value, batch_size)
    if value.ndim == 0:
        value = value.expand(batch_size)
    return value.to(device)


def inpaint(
    x: torch.Tensor,
    n_prefix,
    n_suffix,
):
    """Mask everything except the first ``n_prefix`` and last ``n_suffix`` steps.

    ``n_prefix`` / ``n_suffix`` may be Python scalars or tensors with one value
    per batch item; non-positive values leave that side fully masked.
    """
    assert n_prefix is not None
    assert n_suffix is not None

    mask = full_mask(x)
    batch_size = x.shape[0]

    # Normalise both arguments before comparing: `if tensor > 0` on a
    # multi-element tensor raises, so per-item values are checked one by one.
    prefixes = _as_batch_tensor(n_prefix, batch_size, x.device).tolist()
    suffixes = _as_batch_tensor(n_suffix, batch_size, x.device).tolist()

    for i, n in enumerate(prefixes):
        n = int(n)
        if n > 0:
            mask[i, :, :n] = 0

    for i, n in enumerate(suffixes):
        n = int(n)
        if n > 0:
            mask[i, :, -n:] = 0

    return mask


def periodic_mask(
    x: torch.Tensor,
    period: int,
    width: int = 1,
    random_roll=False,
):
    """Keep (unmask) every ``period``-th timestep and mask everything else.

    Args:
        x: tensor whose shape/dtype/device define the mask.
        period: scalar or one value per batch item; 0 means "keep nothing"
            for that item. If every period is 0 the all-ones mask is returned.
        width: each kept centre ``c`` also keeps ``c - width // 2`` through
            ``c + width // 2`` (clipped to the sequence).
        random_roll: if True, roll the whole mask along time by a random offset
            in ``[0, period[0])``.
    """
    mask = full_mask(x)
    periods = [int(p) for p in _as_batch_tensor(period, x.shape[0], x.device).tolist()]
    if all(p == 0 for p in periods):
        return mask

    seq_len = mask.shape[-1]
    half = width // 2
    for i, p in enumerate(periods):
        if p == 0:
            continue
        # Centres are exactly the multiples of |p| (sign does not affect
        # divisibility), so step straight to them instead of scanning every j.
        for center in range(0, seq_len, abs(p)):
            start = max(0, center - half)
            end = min(seq_len - 1, center + half) + 1
            mask[i, :, start:end] = 0

    # randint needs a positive upper bound; a non-positive period[0] means
    # there is no meaningful offset to sample, so skip the roll.
    if random_roll and periods[0] > 0:
        offset = torch.randint(0, periods[0], (1,))
        mask = torch.roll(mask, offset.item(), dims=-1)

    return mask


def codebook_unmask(
    mask: torch.Tensor,
    n_conditioning_codebooks: int,
):
    """Unmask (set to 0) the first ``n_conditioning_codebooks`` codebooks.

    Returns ``mask`` itself when ``n_conditioning_codebooks`` is None,
    otherwise a modified copy.
    """
    if n_conditioning_codebooks is None:
        return mask
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask


def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
    """Return a copy of ``mask`` with every codebook from index ``val1`` on masked.

    ``val2`` is currently unused; it is kept so existing callers still work.
    """
    mask = mask.clone()
    mask[:, val1:, :] = 1
    return mask


def mask_and(
    mask1: torch.Tensor,
    mask2: torch.Tensor,
):
    """Elementwise AND of two binary masks: masked only where both are masked."""
    assert mask1.shape == mask2.shape, "masks must be same shape"
    return torch.min(mask1, mask2)


def dropout(
    mask: torch.Tensor,
    p: float,
):
    """Additionally mask each currently-unmasked position with probability ``p``.

    Positions that are already masked stay masked. Returns an int64 mask.
    """
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"
    # Invert so kept positions are 1, keep each with prob (1 - p), invert back.
    mask = (~mask.bool()).float()
    mask = torch.bernoulli(mask * (1 - p))
    mask = ~mask.round().bool()
    return mask.long()


def mask_or(
    mask1: torch.Tensor,
    mask2: torch.Tensor,
):
    """Elementwise OR of two binary masks: masked where either is masked."""
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    assert mask1.max() <= 1, "mask1 must be binary"
    assert mask2.max() <= 1, "mask2 must be binary"
    assert mask1.min() >= 0, "mask1 must be binary"
    assert mask2.min() >= 0, "mask2 must be binary"
    return (mask1 + mask2).clamp(0, 1)


def time_stretch_mask(
    x: torch.Tensor,
    stretch_factor: int,
):
    """Keep every ``stretch_factor``-th timestep (width 1) and mask the rest."""
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    # A repeat-interleaved copy of x trimmed back to the original length has
    # the same shape, dtype and device as x, and periodic_mask never reads the
    # values, so the mask can be built from x directly.
    return periodic_mask(x, stretch_factor, width=1)


def onset_mask(
    sig: AudioSignal,
    z: torch.Tensor,
    interface,
    width: int = 1,
):
    """Unmask a window around each detected onset; mask everything else.

    Onsets are detected with librosa on ``sig.samples[0][0]`` (presumably the
    first batch item's first channel — TODO confirm) using the codec's hop
    length, so onset indices are in codec frames. For each onset ``idx``,
    frames ``idx - width`` .. ``idx + width - 1`` (clipped at 0) are set to 0.
    The returned mask has ``z``'s shape and dtype.
    """
    import librosa

    onset_frame_idxs = librosa.onset.onset_detect(
        y=sig.samples[0][0].detach().cpu().numpy(),
        sr=sig.sample_rate,
        hop_length=interface.codec.hop_length,
        backtrack=True,
    )
    if len(onset_frame_idxs) == 0:
        print("no onsets detected")
    print("onset_frame_idxs", onset_frame_idxs)
    print("mask shape", z.shape)

    mask = torch.ones_like(z)
    for idx in onset_frame_idxs:
        # Clamp the start: a negative start would wrap to the end of the
        # sequence and silently produce an empty slice for early onsets.
        start = max(0, int(idx) - width)
        mask[:, :, start:int(idx) + width] = 0

    return mask


if __name__ == "__main__":
    sig = AudioSignal("assets/example.wav")