Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import pytest | |
import torch | |
from audiocraft.modules.codebooks_patterns import ( | |
DelayedPatternProvider, | |
ParallelPatternProvider, | |
Pattern, | |
UnrolledPatternProvider, | |
) | |
class TestParallelPatternProvider: | |
def test_get_pattern(self, n_q: int, timesteps: int): | |
provider = ParallelPatternProvider(n_q) | |
pattern = provider.get_pattern(timesteps) | |
# + 1 to account for 1st step | |
assert len(pattern.layout) == timesteps + 1 | |
def test_pattern_content(self, n_q: int, timesteps: int): | |
provider = ParallelPatternProvider(n_q) | |
pattern = provider.get_pattern(timesteps) | |
for s, v in enumerate(pattern.layout): | |
for i, code in enumerate(v): | |
assert i == code.q | |
assert code.t == s - 1 # account for the 1st empty step | |
def test_pattern_max_delay(self, n_q: int, timesteps: int): | |
provider = ParallelPatternProvider(n_q) | |
pattern = provider.get_pattern(timesteps) | |
assert pattern.max_delay == 0 | |
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay | |
class TestDelayedPatternProvider: | |
def test_get_pattern(self, n_q: int, timesteps: int): | |
delays = [ | |
list(range(n_q)), | |
[0] + [1] * (n_q - 1), | |
[0] + [4] * (n_q - 1), | |
] | |
for delay in delays: | |
provider = DelayedPatternProvider(n_q, delay) | |
pattern = provider.get_pattern(timesteps) | |
# + 1 to account for 1st step | |
assert len(pattern.layout) == timesteps + max(delay) + 1 | |
def test_pattern_content(self, n_q: int, timesteps: int): | |
provider = DelayedPatternProvider(n_q) | |
pattern = provider.get_pattern(timesteps) | |
for s, v in enumerate(pattern.layout): | |
for i, code in enumerate(v): | |
assert i == code.q | |
assert code.t == max(0, s - code.q - 1) | |
def test_pattern_max_delay(self, timesteps: int, delay: list): | |
provider = DelayedPatternProvider(len(delay), delay) | |
pattern = provider.get_pattern(timesteps) | |
assert pattern.max_delay == max(delay) | |
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay | |
class TestUnrolledPatternProvider: | |
def test_get_pattern(self, timesteps: int, flattening: list, delays: list): | |
n_q = len(flattening) | |
max_delay = max(delays) | |
provider = UnrolledPatternProvider(n_q, flattening, delays) | |
pattern = provider.get_pattern(timesteps) | |
assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay | |
def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list): | |
n_q = len(flattening) | |
max_delay = max(delays) | |
provider = UnrolledPatternProvider(n_q, flattening, delays) | |
pattern = provider.get_pattern(timesteps) | |
assert pattern.max_delay == max_delay | |
class TestPattern: | |
def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): | |
"""Reference method to build the sequence from the pattern without using fancy scatter.""" | |
bs, n_q, T = z.shape | |
z = z.cpu().numpy() | |
assert n_q == pattern.n_q | |
assert T <= pattern.timesteps | |
inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy() | |
inp[:] = special_token | |
for s, v in enumerate(pattern.layout): | |
for (t, q) in v: | |
if t < T: | |
inp[:, q, s] = z[:, q, t] | |
return torch.from_numpy(inp) | |
def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): | |
"""Reference method to revert the sequence from the pattern without using fancy scatter.""" | |
z = z.cpu().numpy() | |
bs, n_q, S = z.shape | |
assert pattern.n_q == n_q | |
inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy() | |
inp[:] = special_token | |
for s, v in enumerate(pattern.layout): | |
for (t, q) in v: | |
if t < pattern.timesteps: | |
inp[:, q, t] = z[:, q, s] | |
return torch.from_numpy(inp) | |
def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float): | |
"""Reference method to revert the logits from the pattern without using fancy scatter.""" | |
z = z.cpu().numpy() | |
bs, card, n_q, S = z.shape | |
assert pattern.n_q == n_q | |
ref_layout = pattern.layout | |
inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy() | |
inp[:] = special_token | |
for s, v in enumerate(ref_layout[1:]): | |
if s < S: | |
for (t, q) in v: | |
if t < pattern.timesteps: | |
inp[:, :, q, t] = z[:, :, q, s] | |
return torch.from_numpy(inp) | |
def _get_pattern_providers(self, n_q: int): | |
pattern_provider_1 = ParallelPatternProvider(n_q) | |
pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q))) | |
pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1)) | |
pattern_provider_4 = UnrolledPatternProvider( | |
n_q, flattening=list(range(n_q)), delays=[0] * n_q | |
) | |
pattern_provider_5 = UnrolledPatternProvider( | |
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q | |
) | |
pattern_provider_6 = UnrolledPatternProvider( | |
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1) | |
) | |
return [ | |
pattern_provider_1, | |
pattern_provider_2, | |
pattern_provider_3, | |
pattern_provider_4, | |
pattern_provider_5, | |
pattern_provider_6, | |
] | |
def test_build_pattern_sequence(self, n_q: int, timesteps: int): | |
bs = 2 | |
card = 256 | |
special_token = card | |
pattern_providers = self._get_pattern_providers(n_q) | |
for pattern_provider in pattern_providers: | |
pattern = pattern_provider.get_pattern(timesteps) | |
# we can correctly build the sequence from the pattern | |
z = torch.randint(0, card, (bs, n_q, timesteps)) | |
ref_res = self.ref_build_pattern_sequence(z, pattern, special_token) | |
res, indexes, mask = pattern.build_pattern_sequence(z, special_token) | |
assert (res == ref_res).float().mean() == 1.0 | |
# expected assertion fails on the number of timesteps | |
invalid_timesteps = [timesteps + 1] | |
if pattern.num_sequence_steps != pattern.timesteps: | |
invalid_timesteps.append(pattern.num_sequence_steps) | |
for i_timesteps in invalid_timesteps: | |
z2 = torch.randint(0, card, (bs, n_q, i_timesteps)) | |
with pytest.raises(AssertionError): | |
pattern.build_pattern_sequence(z2, special_token) | |
# expected assertion fails on the number of codebooks | |
invalid_qs = [0, n_q - 1, n_q + 1] | |
for i_q in invalid_qs: | |
z3 = torch.randint(0, card, (bs, i_q, timesteps)) | |
with pytest.raises(AssertionError): | |
pattern.build_pattern_sequence(z3, special_token) | |
def test_revert_pattern_sequence(self, n_q: int, timesteps: int): | |
bs = 2 | |
card = 256 | |
special_token = card | |
pattern_providers = self._get_pattern_providers(n_q) | |
for pattern_provider in pattern_providers: | |
pattern = pattern_provider.get_pattern(timesteps) | |
# this works assuming previous tests are successful | |
z = torch.randint(0, card, (bs, n_q, timesteps)) | |
s = self.ref_build_pattern_sequence(z, pattern, special_token) | |
ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token) | |
# ensure our reference script retrieve the original sequence | |
assert z.shape == ref_out.shape | |
assert (z == ref_out).float().mean() == 1.0 | |
# now we can test the scatter version | |
out, indexes, mask = pattern.revert_pattern_sequence(s, special_token) | |
assert out.shape == ref_out.shape | |
assert (out == ref_out).float().mean() == 1.0 | |
def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int): | |
bs = 2 | |
special_token = card | |
logits_special_token = float('nan') | |
pattern_providers = self._get_pattern_providers(n_q) | |
for pattern_provider in pattern_providers: | |
pattern = pattern_provider.get_pattern(timesteps) | |
# this works assuming previous tests are successful | |
z = torch.randint(0, card, (bs, n_q, timesteps)) | |
s = self.ref_build_pattern_sequence(z, pattern, special_token) | |
logits = torch.randn((bs, card, n_q, s.shape[-1])) | |
ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token) | |
# ensure our reference script retrieve the original sequence | |
assert ref_out.shape == torch.Size([bs, card, n_q, timesteps]) | |
# now we can test the scatter version | |
out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token) | |
assert out.shape == ref_out.shape | |
assert (out == ref_out).float().mean() == 1.0 | |