fffiloni commited on
Commit
f48c226
·
verified ·
1 Parent(s): 2319c67

Delete audiocraft

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. audiocraft/audiocraft/__init__.py +0 -26
  2. audiocraft/audiocraft/__pycache__/__init__.cpython-311.pyc +0 -0
  3. audiocraft/audiocraft/__pycache__/environment.cpython-311.pyc +0 -0
  4. audiocraft/audiocraft/__pycache__/train.cpython-311.pyc +0 -0
  5. audiocraft/audiocraft/adversarial/__init__.py +0 -22
  6. audiocraft/audiocraft/adversarial/__pycache__/__init__.cpython-311.pyc +0 -0
  7. audiocraft/audiocraft/adversarial/__pycache__/losses.cpython-311.pyc +0 -0
  8. audiocraft/audiocraft/adversarial/discriminators/__init__.py +0 -10
  9. audiocraft/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-311.pyc +0 -0
  10. audiocraft/audiocraft/adversarial/discriminators/__pycache__/base.cpython-311.pyc +0 -0
  11. audiocraft/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-311.pyc +0 -0
  12. audiocraft/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-311.pyc +0 -0
  13. audiocraft/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-311.pyc +0 -0
  14. audiocraft/audiocraft/adversarial/discriminators/base.py +0 -34
  15. audiocraft/audiocraft/adversarial/discriminators/mpd.py +0 -106
  16. audiocraft/audiocraft/adversarial/discriminators/msd.py +0 -126
  17. audiocraft/audiocraft/adversarial/discriminators/msstftd.py +0 -134
  18. audiocraft/audiocraft/adversarial/losses.py +0 -228
  19. audiocraft/audiocraft/data/__init__.py +0 -10
  20. audiocraft/audiocraft/data/__pycache__/__init__.cpython-311.pyc +0 -0
  21. audiocraft/audiocraft/data/__pycache__/audio.cpython-311.pyc +0 -0
  22. audiocraft/audiocraft/data/__pycache__/audio_dataset.cpython-311.pyc +0 -0
  23. audiocraft/audiocraft/data/__pycache__/audio_utils.cpython-311.pyc +0 -0
  24. audiocraft/audiocraft/data/__pycache__/btc_chords.cpython-311.pyc +0 -0
  25. audiocraft/audiocraft/data/__pycache__/chords.cpython-311.pyc +0 -0
  26. audiocraft/audiocraft/data/__pycache__/info_audio_dataset.cpython-311.pyc +0 -0
  27. audiocraft/audiocraft/data/__pycache__/music_dataset.cpython-311.pyc +0 -0
  28. audiocraft/audiocraft/data/__pycache__/sound_dataset.cpython-311.pyc +0 -0
  29. audiocraft/audiocraft/data/__pycache__/zip.cpython-311.pyc +0 -0
  30. audiocraft/audiocraft/data/audio.py +0 -257
  31. audiocraft/audiocraft/data/audio_dataset.py +0 -614
  32. audiocraft/audiocraft/data/audio_utils.py +0 -385
  33. audiocraft/audiocraft/data/btc_chords.py +0 -524
  34. audiocraft/audiocraft/data/chords.py +0 -524
  35. audiocraft/audiocraft/data/info_audio_dataset.py +0 -110
  36. audiocraft/audiocraft/data/music_dataset.py +0 -349
  37. audiocraft/audiocraft/data/sound_dataset.py +0 -330
  38. audiocraft/audiocraft/data/zip.py +0 -76
  39. audiocraft/audiocraft/environment.py +0 -176
  40. audiocraft/audiocraft/grids/__init__.py +0 -6
  41. audiocraft/audiocraft/grids/_base_explorers.py +0 -80
  42. audiocraft/audiocraft/grids/audiogen/__init__.py +0 -6
  43. audiocraft/audiocraft/grids/audiogen/audiogen_base_16khz.py +0 -23
  44. audiocraft/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +0 -68
  45. audiocraft/audiocraft/grids/compression/__init__.py +0 -6
  46. audiocraft/audiocraft/grids/compression/_explorers.py +0 -55
  47. audiocraft/audiocraft/grids/compression/debug.py +0 -31
  48. audiocraft/audiocraft/grids/compression/encodec_audiogen_16khz.py +0 -29
  49. audiocraft/audiocraft/grids/compression/encodec_base_24khz.py +0 -28
  50. audiocraft/audiocraft/grids/compression/encodec_musicgen_32khz.py +0 -34
audiocraft/audiocraft/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """
7
- AudioCraft is a general framework for training audio generative models.
8
- At the moment we provide the training code for:
9
-
10
- - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
11
- text-to-music and melody+text autoregressive generative model.
12
- For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
13
- `audiocraft.models.musicgen.MusicGen`.
14
- - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
15
- text-to-general-audio generative model.
16
- - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
17
- neural audio codec which provides an excellent tokenizer for autoregressive language models.
18
- See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
19
- - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
20
- improves the perceived quality and reduces the artifacts coming from adversarial decoders.
21
- """
22
-
23
- # flake8: noqa
24
- from . import data, modules, models
25
-
26
- __version__ = '1.0.0'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.29 kB)
 
audiocraft/audiocraft/__pycache__/environment.cpython-311.pyc DELETED
Binary file (10.5 kB)
 
audiocraft/audiocraft/__pycache__/train.cpython-311.pyc DELETED
Binary file (9.52 kB)
 
audiocraft/audiocraft/adversarial/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Adversarial losses and discriminator architectures."""
7
-
8
- # flake8: noqa
9
- from .discriminators import (
10
- MultiPeriodDiscriminator,
11
- MultiScaleDiscriminator,
12
- MultiScaleSTFTDiscriminator
13
- )
14
- from .losses import (
15
- AdversarialLoss,
16
- AdvLossType,
17
- get_adv_criterion,
18
- get_fake_criterion,
19
- get_real_criterion,
20
- FeatLossType,
21
- FeatureMatchingLoss
22
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/adversarial/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (740 Bytes)
 
audiocraft/audiocraft/adversarial/__pycache__/losses.cpython-311.pyc DELETED
Binary file (15.9 kB)
 
audiocraft/audiocraft/adversarial/discriminators/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # flake8: noqa
8
- from .mpd import MultiPeriodDiscriminator
9
- from .msd import MultiScaleDiscriminator
10
- from .msstftd import MultiScaleSTFTDiscriminator
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (411 Bytes)
 
audiocraft/audiocraft/adversarial/discriminators/__pycache__/base.cpython-311.pyc DELETED
Binary file (1.87 kB)
 
audiocraft/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-311.pyc DELETED
Binary file (7.01 kB)
 
audiocraft/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-311.pyc DELETED
Binary file (8.88 kB)
 
audiocraft/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-311.pyc DELETED
Binary file (9.98 kB)
 
audiocraft/audiocraft/adversarial/discriminators/base.py DELETED
@@ -1,34 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from abc import ABC, abstractmethod
8
- import typing as tp
9
-
10
- import torch
11
- import torch.nn as nn
12
-
13
-
14
- FeatureMapType = tp.List[torch.Tensor]
15
- LogitsType = torch.Tensor
16
- MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
17
-
18
-
19
- class MultiDiscriminator(ABC, nn.Module):
20
- """Base implementation for discriminators composed of sub-discriminators acting at different scales.
21
- """
22
- def __init__(self):
23
- super().__init__()
24
-
25
- @abstractmethod
26
- def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
27
- ...
28
-
29
- @property
30
- @abstractmethod
31
- def num_discriminators(self) -> int:
32
- """Number of discriminators.
33
- """
34
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/adversarial/discriminators/mpd.py DELETED
@@ -1,106 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import typing as tp
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
-
13
- from ...modules import NormConv2d
14
- from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
-
16
-
17
- def get_padding(kernel_size: int, dilation: int = 1) -> int:
18
- return int((kernel_size * dilation - dilation) / 2)
19
-
20
-
21
- class PeriodDiscriminator(nn.Module):
22
- """Period sub-discriminator.
23
-
24
- Args:
25
- period (int): Period between samples of audio.
26
- in_channels (int): Number of input channels.
27
- out_channels (int): Number of output channels.
28
- n_layers (int): Number of convolutional layers.
29
- kernel_sizes (list of int): Kernel sizes for convolutions.
30
- stride (int): Stride for convolutions.
31
- filters (int): Initial number of filters in convolutions.
32
- filters_scale (int): Multiplier of number of filters as we increase depth.
33
- max_filters (int): Maximum number of filters.
34
- norm (str): Normalization method.
35
- activation (str): Activation function.
36
- activation_params (dict): Parameters to provide to the activation function.
37
- """
38
- def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
39
- n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
40
- filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
41
- norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
- activation_params: dict = {'negative_slope': 0.2}):
43
- super().__init__()
44
- self.period = period
45
- self.n_layers = n_layers
46
- self.activation = getattr(torch.nn, activation)(**activation_params)
47
- self.convs = nn.ModuleList()
48
- in_chs = in_channels
49
- for i in range(self.n_layers):
50
- out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
51
- eff_stride = 1 if i == self.n_layers - 1 else stride
52
- self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
53
- padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
54
- in_chs = out_chs
55
- self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
56
- padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
57
-
58
- def forward(self, x: torch.Tensor):
59
- fmap = []
60
- # 1d to 2d
61
- b, c, t = x.shape
62
- if t % self.period != 0: # pad first
63
- n_pad = self.period - (t % self.period)
64
- x = F.pad(x, (0, n_pad), 'reflect')
65
- t = t + n_pad
66
- x = x.view(b, c, t // self.period, self.period)
67
-
68
- for conv in self.convs:
69
- x = conv(x)
70
- x = self.activation(x)
71
- fmap.append(x)
72
- x = self.conv_post(x)
73
- fmap.append(x)
74
- # x = torch.flatten(x, 1, -1)
75
-
76
- return x, fmap
77
-
78
-
79
- class MultiPeriodDiscriminator(MultiDiscriminator):
80
- """Multi-Period (MPD) Discriminator.
81
-
82
- Args:
83
- in_channels (int): Number of input channels.
84
- out_channels (int): Number of output channels.
85
- periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
86
- **kwargs: Additional args for `PeriodDiscriminator`
87
- """
88
- def __init__(self, in_channels: int = 1, out_channels: int = 1,
89
- periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
90
- super().__init__()
91
- self.discriminators = nn.ModuleList([
92
- PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
93
- ])
94
-
95
- @property
96
- def num_discriminators(self):
97
- return len(self.discriminators)
98
-
99
- def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
100
- logits = []
101
- fmaps = []
102
- for disc in self.discriminators:
103
- logit, fmap = disc(x)
104
- logits.append(logit)
105
- fmaps.append(fmap)
106
- return logits, fmaps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/adversarial/discriminators/msd.py DELETED
@@ -1,126 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import typing as tp
8
-
9
- import numpy as np
10
- import torch
11
- import torch.nn as nn
12
-
13
- from ...modules import NormConv1d
14
- from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
-
16
-
17
- class ScaleDiscriminator(nn.Module):
18
- """Waveform sub-discriminator.
19
-
20
- Args:
21
- in_channels (int): Number of input channels.
22
- out_channels (int): Number of output channels.
23
- kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
24
- filters (int): Number of initial filters for convolutions.
25
- max_filters (int): Maximum number of filters.
26
- downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
27
- inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
28
- groups (Sequence[int] or None): Groups for inner convolutions.
29
- strides (Sequence[int] or None): Strides for inner convolutions.
30
- paddings (Sequence[int] or None): Paddings for inner convolutions.
31
- norm (str): Normalization method.
32
- activation (str): Activation function.
33
- activation_params (dict): Parameters to provide to the activation function.
34
- pad (str): Padding for initial convolution.
35
- pad_params (dict): Parameters to provide to the padding module.
36
- """
37
- def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
38
- filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
39
- inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
40
- strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
41
- norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
- activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
43
- pad_params: dict = {}):
44
- super().__init__()
45
- assert len(kernel_sizes) == 2
46
- assert kernel_sizes[0] % 2 == 1
47
- assert kernel_sizes[1] % 2 == 1
48
- assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
49
- assert (groups is None or len(groups) == len(downsample_scales))
50
- assert (strides is None or len(strides) == len(downsample_scales))
51
- assert (paddings is None or len(paddings) == len(downsample_scales))
52
- self.activation = getattr(torch.nn, activation)(**activation_params)
53
- self.convs = nn.ModuleList()
54
- self.convs.append(
55
- nn.Sequential(
56
- getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
57
- NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
58
- )
59
- )
60
-
61
- in_chs = filters
62
- for i, downsample_scale in enumerate(downsample_scales):
63
- out_chs = min(in_chs * downsample_scale, max_filters)
64
- default_kernel_size = downsample_scale * 10 + 1
65
- default_stride = downsample_scale
66
- default_padding = (default_kernel_size - 1) // 2
67
- default_groups = in_chs // 4
68
- self.convs.append(
69
- NormConv1d(in_chs, out_chs,
70
- kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
71
- stride=strides[i] if strides else default_stride,
72
- groups=groups[i] if groups else default_groups,
73
- padding=paddings[i] if paddings else default_padding,
74
- norm=norm))
75
- in_chs = out_chs
76
-
77
- out_chs = min(in_chs * 2, max_filters)
78
- self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
79
- padding=(kernel_sizes[0] - 1) // 2, norm=norm))
80
- self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
81
- padding=(kernel_sizes[1] - 1) // 2, norm=norm)
82
-
83
- def forward(self, x: torch.Tensor):
84
- fmap = []
85
- for layer in self.convs:
86
- x = layer(x)
87
- x = self.activation(x)
88
- fmap.append(x)
89
- x = self.conv_post(x)
90
- fmap.append(x)
91
- # x = torch.flatten(x, 1, -1)
92
- return x, fmap
93
-
94
-
95
- class MultiScaleDiscriminator(MultiDiscriminator):
96
- """Multi-Scale (MSD) Discriminator,
97
-
98
- Args:
99
- in_channels (int): Number of input channels.
100
- out_channels (int): Number of output channels.
101
- downsample_factor (int): Downsampling factor between the different scales.
102
- scale_norms (Sequence[str]): Normalization for each sub-discriminator.
103
- **kwargs: Additional args for ScaleDiscriminator.
104
- """
105
- def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
106
- scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
107
- super().__init__()
108
- self.discriminators = nn.ModuleList([
109
- ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
110
- ])
111
- self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
112
-
113
- @property
114
- def num_discriminators(self):
115
- return len(self.discriminators)
116
-
117
- def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
118
- logits = []
119
- fmaps = []
120
- for i, disc in enumerate(self.discriminators):
121
- if i != 0:
122
- self.downsample(x)
123
- logit, fmap = disc(x)
124
- logits.append(logit)
125
- fmaps.append(fmap)
126
- return logits, fmaps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/adversarial/discriminators/msstftd.py DELETED
@@ -1,134 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import typing as tp
8
-
9
- import torchaudio
10
- import torch
11
- from torch import nn
12
- from einops import rearrange
13
-
14
- from ...modules import NormConv2d
15
- from .base import MultiDiscriminator, MultiDiscriminatorOutputType
16
-
17
-
18
- def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
19
- return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
20
-
21
-
22
- class DiscriminatorSTFT(nn.Module):
23
- """STFT sub-discriminator.
24
-
25
- Args:
26
- filters (int): Number of filters in convolutions.
27
- in_channels (int): Number of input channels.
28
- out_channels (int): Number of output channels.
29
- n_fft (int): Size of FFT for each scale.
30
- hop_length (int): Length of hop between STFT windows for each scale.
31
- kernel_size (tuple of int): Inner Conv2d kernel sizes.
32
- stride (tuple of int): Inner Conv2d strides.
33
- dilations (list of int): Inner Conv2d dilation on the time dimension.
34
- win_length (int): Window size for each scale.
35
- normalized (bool): Whether to normalize by magnitude after stft.
36
- norm (str): Normalization method.
37
- activation (str): Activation function.
38
- activation_params (dict): Parameters to provide to the activation function.
39
- growth (int): Growth factor for the filters.
40
- """
41
- def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
42
- n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
43
- filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
44
- stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
45
- activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
46
- super().__init__()
47
- assert len(kernel_size) == 2
48
- assert len(stride) == 2
49
- self.filters = filters
50
- self.in_channels = in_channels
51
- self.out_channels = out_channels
52
- self.n_fft = n_fft
53
- self.hop_length = hop_length
54
- self.win_length = win_length
55
- self.normalized = normalized
56
- self.activation = getattr(torch.nn, activation)(**activation_params)
57
- self.spec_transform = torchaudio.transforms.Spectrogram(
58
- n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
59
- normalized=self.normalized, center=False, pad_mode=None, power=None)
60
- spec_channels = 2 * self.in_channels
61
- self.convs = nn.ModuleList()
62
- self.convs.append(
63
- NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
64
- )
65
- in_chs = min(filters_scale * self.filters, max_filters)
66
- for i, dilation in enumerate(dilations):
67
- out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
68
- self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
69
- dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
70
- norm=norm))
71
- in_chs = out_chs
72
- out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
73
- self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
74
- padding=get_2d_padding((kernel_size[0], kernel_size[0])),
75
- norm=norm))
76
- self.conv_post = NormConv2d(out_chs, self.out_channels,
77
- kernel_size=(kernel_size[0], kernel_size[0]),
78
- padding=get_2d_padding((kernel_size[0], kernel_size[0])),
79
- norm=norm)
80
-
81
- def forward(self, x: torch.Tensor):
82
- fmap = []
83
- z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
84
- z = torch.cat([z.real, z.imag], dim=1)
85
- z = rearrange(z, 'b c w t -> b c t w')
86
- for i, layer in enumerate(self.convs):
87
- z = layer(z)
88
- z = self.activation(z)
89
- fmap.append(z)
90
- z = self.conv_post(z)
91
- return z, fmap
92
-
93
-
94
- class MultiScaleSTFTDiscriminator(MultiDiscriminator):
95
- """Multi-Scale STFT (MS-STFT) discriminator.
96
-
97
- Args:
98
- filters (int): Number of filters in convolutions.
99
- in_channels (int): Number of input channels.
100
- out_channels (int): Number of output channels.
101
- sep_channels (bool): Separate channels to distinct samples for stereo support.
102
- n_ffts (Sequence[int]): Size of FFT for each scale.
103
- hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
104
- win_lengths (Sequence[int]): Window size for each scale.
105
- **kwargs: Additional args for STFTDiscriminator.
106
- """
107
- def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
108
- n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
109
- win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
110
- super().__init__()
111
- assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
112
- self.sep_channels = sep_channels
113
- self.discriminators = nn.ModuleList([
114
- DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
115
- n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
116
- for i in range(len(n_ffts))
117
- ])
118
-
119
- @property
120
- def num_discriminators(self):
121
- return len(self.discriminators)
122
-
123
- def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
124
- B, C, T = x.shape
125
- return x.view(-1, 1, T)
126
-
127
- def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
128
- logits = []
129
- fmaps = []
130
- for disc in self.discriminators:
131
- logit, fmap = disc(x)
132
- logits.append(logit)
133
- fmaps.append(fmap)
134
- return logits, fmaps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/adversarial/losses.py DELETED
@@ -1,228 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Utility module to handle adversarial losses without requiring to mess up the main training loop.
9
- """
10
-
11
- import typing as tp
12
-
13
- import flashy
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
-
18
-
19
- ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
20
-
21
-
22
- AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
23
- FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
24
-
25
-
26
- class AdversarialLoss(nn.Module):
27
- """Adversary training wrapper.
28
-
29
- Args:
30
- adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
31
- We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
32
- where the first item is a list of logits and the second item is a list of feature maps.
33
- optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
34
- loss (AdvLossType): Loss function for generator training.
35
- loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
36
- loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
37
- loss_feat (FeatLossType): Feature matching loss function for generator training.
38
- normalize (bool): Whether to normalize by number of sub-discriminators.
39
-
40
- Example of usage:
41
- adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
42
- for real in loader:
43
- noise = torch.randn(...)
44
- fake = model(noise)
45
- adv_loss.train_adv(fake, real)
46
- loss, _ = adv_loss(fake, real)
47
- loss.backward()
48
- """
49
- def __init__(self,
50
- adversary: nn.Module,
51
- optimizer: torch.optim.Optimizer,
52
- loss: AdvLossType,
53
- loss_real: AdvLossType,
54
- loss_fake: AdvLossType,
55
- loss_feat: tp.Optional[FeatLossType] = None,
56
- normalize: bool = True):
57
- super().__init__()
58
- self.adversary: nn.Module = adversary
59
- flashy.distrib.broadcast_model(self.adversary)
60
- self.optimizer = optimizer
61
- self.loss = loss
62
- self.loss_real = loss_real
63
- self.loss_fake = loss_fake
64
- self.loss_feat = loss_feat
65
- self.normalize = normalize
66
-
67
- def _save_to_state_dict(self, destination, prefix, keep_vars):
68
- # Add the optimizer state dict inside our own.
69
- super()._save_to_state_dict(destination, prefix, keep_vars)
70
- destination[prefix + 'optimizer'] = self.optimizer.state_dict()
71
- return destination
72
-
73
- def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
74
- # Load optimizer state.
75
- self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
76
- super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
77
-
78
- def get_adversary_pred(self, x):
79
- """Run adversary model, validating expected output format."""
80
- logits, fmaps = self.adversary(x)
81
- assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
82
- f'Expecting a list of tensors as logits but {type(logits)} found.'
83
- assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
84
- for fmap in fmaps:
85
- assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
86
- f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
87
- return logits, fmaps
88
-
89
- def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
90
- """Train the adversary with the given fake and real example.
91
-
92
- We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
93
- The first item being the logits and second item being a list of feature maps for each sub-discriminator.
94
-
95
- This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
96
- and call the optimizer.
97
- """
98
- loss = torch.tensor(0., device=fake.device)
99
- all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
100
- all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
101
- n_sub_adversaries = len(all_logits_fake_is_fake)
102
- for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
103
- loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
104
-
105
- if self.normalize:
106
- loss /= n_sub_adversaries
107
-
108
- self.optimizer.zero_grad()
109
- with flashy.distrib.eager_sync_model(self.adversary):
110
- loss.backward()
111
- self.optimizer.step()
112
-
113
- return loss
114
-
115
- def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
116
- """Return the loss for the generator, i.e. trying to fool the adversary,
117
- and feature matching loss if provided.
118
- """
119
- adv = torch.tensor(0., device=fake.device)
120
- feat = torch.tensor(0., device=fake.device)
121
- with flashy.utils.readonly(self.adversary):
122
- all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
123
- all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
124
- n_sub_adversaries = len(all_logits_fake_is_fake)
125
- for logit_fake_is_fake in all_logits_fake_is_fake:
126
- adv += self.loss(logit_fake_is_fake)
127
- if self.loss_feat:
128
- for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
129
- feat += self.loss_feat(fmap_fake, fmap_real)
130
-
131
- if self.normalize:
132
- adv /= n_sub_adversaries
133
- feat /= n_sub_adversaries
134
-
135
- return adv, feat
136
-
137
-
138
- def get_adv_criterion(loss_type: str) -> tp.Callable:
139
- assert loss_type in ADVERSARIAL_LOSSES
140
- if loss_type == 'mse':
141
- return mse_loss
142
- elif loss_type == 'hinge':
143
- return hinge_loss
144
- elif loss_type == 'hinge2':
145
- return hinge2_loss
146
- raise ValueError('Unsupported loss')
147
-
148
-
149
- def get_fake_criterion(loss_type: str) -> tp.Callable:
150
- assert loss_type in ADVERSARIAL_LOSSES
151
- if loss_type == 'mse':
152
- return mse_fake_loss
153
- elif loss_type in ['hinge', 'hinge2']:
154
- return hinge_fake_loss
155
- raise ValueError('Unsupported loss')
156
-
157
-
158
- def get_real_criterion(loss_type: str) -> tp.Callable:
159
- assert loss_type in ADVERSARIAL_LOSSES
160
- if loss_type == 'mse':
161
- return mse_real_loss
162
- elif loss_type in ['hinge', 'hinge2']:
163
- return hinge_real_loss
164
- raise ValueError('Unsupported loss')
165
-
166
-
167
- def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
168
- return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
169
-
170
-
171
- def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
172
- return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
173
-
174
-
175
- def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
176
- return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
177
-
178
-
179
- def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
180
- return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
181
-
182
-
183
- def mse_loss(x: torch.Tensor) -> torch.Tensor:
184
- if x.numel() == 0:
185
- return torch.tensor([0.0], device=x.device)
186
- return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
187
-
188
-
189
- def hinge_loss(x: torch.Tensor) -> torch.Tensor:
190
- if x.numel() == 0:
191
- return torch.tensor([0.0], device=x.device)
192
- return -x.mean()
193
-
194
-
195
- def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
196
- if x.numel() == 0:
197
- return torch.tensor([0.0])
198
- return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
199
-
200
-
201
- class FeatureMatchingLoss(nn.Module):
202
- """Feature matching loss for adversarial training.
203
-
204
- Args:
205
- loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
206
- normalize (bool): Whether to normalize the loss.
207
- by number of feature maps.
208
- """
209
- def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
210
- super().__init__()
211
- self.loss = loss
212
- self.normalize = normalize
213
-
214
- def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
215
- assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
216
- feat_loss = torch.tensor(0., device=fmap_fake[0].device)
217
- feat_scale = torch.tensor(0., device=fmap_fake[0].device)
218
- n_fmaps = 0
219
- for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
220
- assert feat_fake.shape == feat_real.shape
221
- n_fmaps += 1
222
- feat_loss += self.loss(feat_fake, feat_real)
223
- feat_scale += torch.mean(torch.abs(feat_real))
224
-
225
- if self.normalize:
226
- feat_loss /= n_fmaps
227
-
228
- return feat_loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Audio loading and writing support. Datasets for raw audio
7
- or also including some metadata."""
8
-
9
- # flake8: noqa
10
- from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset, btc_chords
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (493 Bytes)
 
audiocraft/audiocraft/data/__pycache__/audio.cpython-311.pyc DELETED
Binary file (14.9 kB)
 
audiocraft/audiocraft/data/__pycache__/audio_dataset.cpython-311.pyc DELETED
Binary file (36.7 kB)
 
audiocraft/audiocraft/data/__pycache__/audio_utils.cpython-311.pyc DELETED
Binary file (21.4 kB)
 
audiocraft/audiocraft/data/__pycache__/btc_chords.cpython-311.pyc DELETED
Binary file (23.4 kB)
 
audiocraft/audiocraft/data/__pycache__/chords.cpython-311.pyc DELETED
Binary file (23.4 kB)
 
audiocraft/audiocraft/data/__pycache__/info_audio_dataset.cpython-311.pyc DELETED
Binary file (7.63 kB)
 
audiocraft/audiocraft/data/__pycache__/music_dataset.cpython-311.pyc DELETED
Binary file (21.8 kB)
 
audiocraft/audiocraft/data/__pycache__/sound_dataset.cpython-311.pyc DELETED
Binary file (18.8 kB)
 
audiocraft/audiocraft/data/__pycache__/zip.cpython-311.pyc DELETED
Binary file (3.68 kB)
 
audiocraft/audiocraft/data/audio.py DELETED
@@ -1,257 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Audio IO methods are defined in this module (info, read, write),
9
- We rely on av library for faster read when possible, otherwise on torchaudio.
10
- """
11
-
12
- from dataclasses import dataclass
13
- from pathlib import Path
14
- import logging
15
- import typing as tp
16
-
17
- import numpy as np
18
- import soundfile
19
- import torch
20
- from torch.nn import functional as F
21
- import torchaudio as ta
22
-
23
- import av
24
-
25
- from .audio_utils import f32_pcm, i16_pcm, normalize_audio
26
-
27
-
28
- _av_initialized = False
29
-
30
-
31
- def _init_av():
32
- global _av_initialized
33
- if _av_initialized:
34
- return
35
- logger = logging.getLogger('libav.mp3')
36
- logger.setLevel(logging.ERROR)
37
- _av_initialized = True
38
-
39
-
40
- @dataclass(frozen=True)
41
- class AudioFileInfo:
42
- sample_rate: int
43
- duration: float
44
- channels: int
45
-
46
-
47
- def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
48
- _init_av()
49
- with av.open(str(filepath)) as af:
50
- stream = af.streams.audio[0]
51
- sample_rate = stream.codec_context.sample_rate
52
- duration = float(stream.duration * stream.time_base)
53
- channels = stream.channels
54
- return AudioFileInfo(sample_rate, duration, channels)
55
-
56
-
57
- def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
58
- info = soundfile.info(filepath)
59
- return AudioFileInfo(info.samplerate, info.duration, info.channels)
60
-
61
-
62
- def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
63
- # torchaudio no longer returns useful duration informations for some formats like mp3s.
64
- filepath = Path(filepath)
65
- if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
66
- # ffmpeg has some weird issue with flac.
67
- return _soundfile_info(filepath)
68
- else:
69
- return _av_info(filepath)
70
-
71
-
72
- def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73
- """FFMPEG-based audio file reading using PyAV bindings.
74
- Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
75
-
76
- Args:
77
- filepath (str or Path): Path to audio file to read.
78
- seek_time (float): Time at which to start reading in the file.
79
- duration (float): Duration to read from the file. If set to -1, the whole file is read.
80
- Returns:
81
- tuple of torch.Tensor, int: Tuple containing audio data and sample rate
82
- """
83
- _init_av()
84
- with av.open(str(filepath)) as af:
85
- stream = af.streams.audio[0]
86
- sr = stream.codec_context.sample_rate
87
- num_frames = int(sr * duration) if duration >= 0 else -1
88
- frame_offset = int(sr * seek_time)
89
- # we need a small negative offset otherwise we get some edge artifact
90
- # from the mp3 decoder.
91
- af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
92
- frames = []
93
- length = 0
94
- for frame in af.decode(streams=stream.index):
95
- current_offset = int(frame.rate * frame.pts * frame.time_base)
96
- strip = max(0, frame_offset - current_offset)
97
- buf = torch.from_numpy(frame.to_ndarray())
98
- if buf.shape[0] != stream.channels:
99
- buf = buf.view(-1, stream.channels).t()
100
- buf = buf[:, strip:]
101
- frames.append(buf)
102
- length += buf.shape[1]
103
- if num_frames > 0 and length >= num_frames:
104
- break
105
- assert frames
106
- # If the above assert fails, it is likely because we seeked past the end of file point,
107
- # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
108
- # This will need proper debugging, in due time.
109
- wav = torch.cat(frames, dim=1)
110
- assert wav.shape[0] == stream.channels
111
- if num_frames > 0:
112
- wav = wav[:, :num_frames]
113
- return f32_pcm(wav), sr
114
-
115
-
116
- def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117
- duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118
- """Read audio by picking the most appropriate backend tool based on the audio format.
119
-
120
- Args:
121
- filepath (str or Path): Path to audio file to read.
122
- seek_time (float): Time at which to start reading in the file.
123
- duration (float): Duration to read from the file. If set to -1, the whole file is read.
124
- pad (bool): Pad output audio if not reaching expected duration.
125
- Returns:
126
- tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
127
- """
128
- fp = Path(filepath)
129
- if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
130
- # There is some bug with ffmpeg and reading flac
131
- info = _soundfile_info(filepath)
132
- frames = -1 if duration <= 0 else int(duration * info.sample_rate)
133
- frame_offset = int(seek_time * info.sample_rate)
134
- wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
135
- assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
136
- wav = torch.from_numpy(wav).t().contiguous()
137
- if len(wav.shape) == 1:
138
- wav = torch.unsqueeze(wav, 0)
139
- elif (
140
- fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
141
- and duration <= 0 and seek_time == 0
142
- ):
143
- # Torchaudio is faster if we load an entire file at once.
144
- wav, sr = ta.load(fp)
145
- else:
146
- wav, sr = _av_read(filepath, seek_time, duration)
147
- if pad and duration > 0:
148
- expected_frames = int(duration * sr)
149
- wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
150
- return wav, sr
151
-
152
-
153
- def audio_write(stem_name: tp.Union[str, Path],
154
- wav: torch.Tensor, sample_rate: int,
155
- format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
156
- strategy: str = 'peak', peak_clip_headroom_db: float = 1,
157
- rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
158
- loudness_compressor: bool = False,
159
- log_clipping: bool = True, make_parent_dir: bool = True,
160
- add_suffix: bool = True) -> Path:
161
- """Convenience function for saving audio to disk. Returns the filename the audio was written to.
162
-
163
- Args:
164
- stem_name (str or Path): Filename without extension which will be added automatically.
165
- format (str): Either "wav" or "mp3".
166
- mp3_rate (int): kbps when using mp3s.
167
- normalize (bool): if `True` (default), normalizes according to the prescribed
168
- strategy (see after). If `False`, the strategy is only used in case clipping
169
- would happen.
170
- strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
171
- i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
172
- with extra headroom to avoid clipping. 'clip' just clips.
173
- peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
174
- rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
175
- than the `peak_clip` one to avoid further clipping.
176
- loudness_headroom_db (float): Target loudness for loudness normalization.
177
- loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
178
- when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
179
- occurs despite strategy (only for 'rms').
180
- make_parent_dir (bool): Make parent directory if it doesn't exist.
181
- Returns:
182
- Path: Path of the saved audio.
183
- """
184
- assert wav.dtype.is_floating_point, "wav is not floating point"
185
- if wav.dim() == 1:
186
- wav = wav[None]
187
- elif wav.dim() > 2:
188
- raise ValueError("Input wav should be at most 2 dimension.")
189
- assert wav.isfinite().all()
190
- wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
191
- rms_headroom_db, loudness_headroom_db, loudness_compressor,
192
- log_clipping=log_clipping, sample_rate=sample_rate,
193
- stem_name=str(stem_name))
194
- kwargs: dict = {}
195
- if format == 'mp3':
196
- suffix = '.mp3'
197
- kwargs.update({"compression": mp3_rate})
198
- elif format == 'wav':
199
- wav = i16_pcm(wav)
200
- suffix = '.wav'
201
- kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
202
- else:
203
- raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
204
- if not add_suffix:
205
- suffix = ''
206
- path = Path(str(stem_name) + suffix)
207
- if make_parent_dir:
208
- path.parent.mkdir(exist_ok=True, parents=True)
209
- try:
210
- ta.save(path, wav, sample_rate, **kwargs)
211
- except Exception:
212
- if path.exists():
213
- # we do not want to leave half written files around.
214
- path.unlink()
215
- raise
216
- return path
217
-
218
- def audio_postproc(wav: torch.Tensor, sample_rate: int, normalize: bool = True,
219
- strategy: str = 'peak', peak_clip_headroom_db: float = 1,
220
- rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
221
- loudness_compressor: bool = False, log_clipping: bool = True) -> Path:
222
- """Convenience function for saving audio to disk. Returns the filename the audio was written to.
223
-
224
- Args:
225
- wav (torch.Tensor): Audio data to save.
226
- sample_rate (int): Sample rate of audio data.
227
- format (str): Either "wav" or "mp3".
228
- mp3_rate (int): kbps when using mp3s.
229
- normalize (bool): if `True` (default), normalizes according to the prescribed
230
- strategy (see after). If `False`, the strategy is only used in case clipping
231
- would happen.
232
- strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
233
- i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
234
- with extra headroom to avoid clipping. 'clip' just clips.
235
- peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
236
- rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
237
- than the `peak_clip` one to avoid further clipping.
238
- loudness_headroom_db (float): Target loudness for loudness normalization.
239
- loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
240
- when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
241
- occurs despite strategy (only for 'rms').
242
- make_parent_dir (bool): Make parent directory if it doesn't exist.
243
- Returns:
244
- Path: Path of the saved audio.
245
- """
246
- assert wav.dtype.is_floating_point, "wav is not floating point"
247
- if wav.dim() == 1:
248
- wav = wav[None]
249
- elif wav.dim() > 2:
250
- raise ValueError("Input wav should be at most 2 dimension.")
251
- assert wav.isfinite().all()
252
- wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
253
- rms_headroom_db, loudness_headroom_db, loudness_compressor,
254
- log_clipping=log_clipping, sample_rate=sample_rate,
255
- stem_name=None)
256
-
257
- return wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/audio_dataset.py DELETED
@@ -1,614 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """AudioDataset support. In order to handle a larger number of files
7
- without having to scan again the folders, we precompute some metadata
8
- (filename, sample rate, duration), and use that to efficiently sample audio segments.
9
- """
10
- import argparse
11
- import copy
12
- from concurrent.futures import ThreadPoolExecutor, Future
13
- from dataclasses import dataclass, fields
14
- from contextlib import ExitStack
15
- from functools import lru_cache
16
- import gzip
17
- import json
18
- import logging
19
- import os
20
- from pathlib import Path
21
- import random
22
- import sys
23
- import typing as tp
24
-
25
- import torch
26
- import torch.nn.functional as F
27
-
28
- from .audio import audio_read, audio_info
29
- from .audio_utils import convert_audio
30
- from .zip import PathInZip
31
-
32
- try:
33
- import dora
34
- except ImportError:
35
- dora = None # type: ignore
36
-
37
-
38
- @dataclass(order=True)
39
- class BaseInfo:
40
-
41
- @classmethod
42
- def _dict2fields(cls, dictionary: dict):
43
- return {
44
- field.name: dictionary[field.name]
45
- for field in fields(cls) if field.name in dictionary
46
- }
47
-
48
- @classmethod
49
- def from_dict(cls, dictionary: dict):
50
- _dictionary = cls._dict2fields(dictionary)
51
- return cls(**_dictionary)
52
-
53
- def to_dict(self):
54
- return {
55
- field.name: self.__getattribute__(field.name)
56
- for field in fields(self)
57
- }
58
-
59
-
60
- @dataclass(order=True)
61
- class AudioMeta(BaseInfo):
62
- path: str
63
- duration: float
64
- sample_rate: int
65
- bpm: float
66
- # meter: int
67
- amplitude: tp.Optional[float] = None
68
- weight: tp.Optional[float] = None
69
- phr_start: tp.List[tp.Optional[float]] = None
70
- # info_path is used to load additional information about the audio file that is stored in zip files.
71
- info_path: tp.Optional[PathInZip] = None
72
-
73
- @classmethod
74
- def from_dict(cls, dictionary: dict):
75
- base = cls._dict2fields(dictionary)
76
- if 'info_path' in base and base['info_path'] is not None:
77
- base['info_path'] = PathInZip(base['info_path'])
78
- return cls(**base)
79
-
80
- def to_dict(self):
81
- d = super().to_dict()
82
- if d['info_path'] is not None:
83
- d['info_path'] = str(d['info_path'])
84
- return d
85
-
86
-
87
- @dataclass(order=True)
88
- class SegmentInfo(BaseInfo):
89
- meta: AudioMeta
90
- seek_time: float
91
- # The following values are given once the audio is processed, e.g.
92
- # at the target sample rate and target number of channels.
93
- n_frames: int # actual number of frames without padding
94
- total_frames: int # total number of frames, padding included
95
- sample_rate: int # actual sample rate
96
- channels: int # number of audio channels.
97
-
98
-
99
- DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
100
-
101
- logger = logging.getLogger(__name__)
102
-
103
-
104
- def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
105
- """AudioMeta from a path to an audio file.
106
-
107
- Args:
108
- file_path (str): Resolved path of valid audio file.
109
- minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
110
- Returns:
111
- AudioMeta: Audio file path and its metadata.
112
- """
113
- info = audio_info(file_path)
114
- amplitude: tp.Optional[float] = None
115
- if not minimal:
116
- wav, sr = audio_read(file_path)
117
- amplitude = wav.abs().max().item()
118
-
119
- # load json info
120
- json_file = file_path.replace('.wav', '.json')
121
- with open(json_file ,'r') as f:
122
- json_str = f.read()
123
- info_json = json.loads(json_str)
124
-
125
- if "phr_start" not in info_json.keys():
126
- info_json["phr_start"] = None
127
-
128
- # return AudioMeta(file_path, info.duration, info.sample_rate, info_json["bpm"], info_json["meter"], amplitude, None, info_json["phr_start"])
129
- return AudioMeta(file_path, info.duration, info.sample_rate, info_json["bpm"], amplitude, None, info_json["phr_start"])
130
-
131
- def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
132
- """If Dora is available as a dependency, try to resolve potential relative paths
133
- in list of AudioMeta. This method is expected to be used when loading meta from file.
134
-
135
- Args:
136
- m (AudioMeta): Audio meta to resolve.
137
- fast (bool): If True, uses a really fast check for determining if a file
138
- is already absolute or not. Only valid on Linux/Mac.
139
- Returns:
140
- AudioMeta: Audio meta with resolved path.
141
- """
142
- def is_abs(m):
143
- if fast:
144
- return str(m)[0] == '/'
145
- else:
146
- os.path.isabs(str(m))
147
-
148
- if not dora:
149
- return m
150
-
151
- if not is_abs(m.path):
152
- m.path = dora.git_save.to_absolute_path(m.path)
153
- if m.info_path is not None and not is_abs(m.info_path.zip_path):
154
- m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
155
- return m
156
-
157
-
158
- def find_audio_files(path: tp.Union[Path, str],
159
- exts: tp.List[str] = DEFAULT_EXTS,
160
- resolve: bool = True,
161
- minimal: bool = True,
162
- progress: bool = False,
163
- workers: int = 0) -> tp.List[AudioMeta]:
164
- """Build a list of AudioMeta from a given path,
165
- collecting relevant audio files and fetching meta info.
166
-
167
- Args:
168
- path (str or Path): Path to folder containing audio files.
169
- exts (list of str): List of file extensions to consider for audio files.
170
- minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
171
- progress (bool): Whether to log progress on audio files collection.
172
- workers (int): number of parallel workers, if 0, use only the current thread.
173
- Returns:
174
- list of AudioMeta: List of audio file path and its metadata.
175
- """
176
- audio_files = []
177
- futures: tp.List[Future] = []
178
- pool: tp.Optional[ThreadPoolExecutor] = None
179
- with ExitStack() as stack:
180
- if workers > 0:
181
- pool = ThreadPoolExecutor(workers)
182
- stack.enter_context(pool)
183
-
184
- if progress:
185
- print("Finding audio files...")
186
- for root, folders, files in os.walk(path, followlinks=True):
187
- for file in files:
188
- full_path = Path(root) / file
189
- if full_path.suffix.lower() in exts:
190
- audio_files.append(full_path)
191
- if pool is not None:
192
- futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
193
- if progress:
194
- print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
195
-
196
- if progress:
197
- print("Getting audio metadata...")
198
- meta: tp.List[AudioMeta] = []
199
- for idx, file_path in enumerate(audio_files):
200
- try:
201
- if pool is None:
202
- m = _get_audio_meta(str(file_path), minimal)
203
- else:
204
- m = futures[idx].result()
205
- if resolve:
206
- m = _resolve_audio_meta(m)
207
- except Exception as err:
208
- print("Error with", str(file_path), err, file=sys.stderr)
209
- continue
210
- meta.append(m)
211
- if progress:
212
- print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
213
- meta.sort()
214
- return meta
215
-
216
-
217
- def load_audio_meta(path: tp.Union[str, Path],
218
- resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
219
- """Load list of AudioMeta from an optionally compressed json file.
220
-
221
- Args:
222
- path (str or Path): Path to JSON file.
223
- resolve (bool): Whether to resolve the path from AudioMeta (default=True).
224
- fast (bool): activates some tricks to make things faster.
225
- Returns:
226
- list of AudioMeta: List of audio file path and its total duration.
227
- """
228
- open_fn = gzip.open if str(path).lower().endswith('.gz') else open
229
- with open_fn(path, 'rb') as fp: # type: ignore
230
- lines = fp.readlines()
231
- meta = []
232
- for line in lines:
233
- d = json.loads(line)
234
- m = AudioMeta.from_dict(d)
235
- if resolve:
236
- m = _resolve_audio_meta(m, fast=fast)
237
- meta.append(m)
238
- return meta
239
-
240
-
241
- def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
242
- """Save the audio metadata to the file pointer as json.
243
-
244
- Args:
245
- path (str or Path): Path to JSON file.
246
- metadata (list of BaseAudioMeta): List of audio meta to save.
247
- """
248
- Path(path).parent.mkdir(exist_ok=True, parents=True)
249
- open_fn = gzip.open if str(path).lower().endswith('.gz') else open
250
- with open_fn(path, 'wb') as fp: # type: ignore
251
- for m in meta:
252
- json_str = json.dumps(m.to_dict()) + '\n'
253
- json_bytes = json_str.encode('utf-8')
254
- fp.write(json_bytes)
255
-
256
-
257
- class AudioDataset:
258
- """Base audio dataset.
259
-
260
- The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
261
- and potentially additional information, by creating random segments from the list of audio
262
- files referenced in the metadata and applying minimal data pre-processing such as resampling,
263
- mixing of channels, padding, etc.
264
-
265
- If no segment_duration value is provided, the AudioDataset will return the full wav for each
266
- audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
267
- duration, applying padding if required.
268
-
269
- By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
270
- allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
271
- original audio meta.
272
-
273
- Note that you can call `start_epoch(epoch)` in order to get
274
- a deterministic "randomization" for `shuffle=True`.
275
- For a given epoch and dataset index, this will always return the same extract.
276
- You can get back some diversity by setting the `shuffle_seed` param.
277
-
278
- Args:
279
- meta (list of AudioMeta): List of audio files metadata.
280
- segment_duration (float, optional): Optional segment duration of audio to load.
281
- If not specified, the dataset will load the full audio segment from the file.
282
- shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
283
- sample_rate (int): Target sample rate of the loaded audio samples.
284
- channels (int): Target number of channels of the loaded audio samples.
285
- sample_on_duration (bool): Set to `True` to sample segments with probability
286
- dependent on audio file duration. This is only used if `segment_duration` is provided.
287
- sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
288
- `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
289
- of the file duration and file weight. This is only used if `segment_duration` is provided.
290
- min_segment_ratio (float): Minimum segment ratio to use when the audio file
291
- is shorter than the desired segment.
292
- max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
293
- return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
294
- min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
295
- audio shorter than this will be filtered out.
296
- max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
297
- audio longer than this will be filtered out.
298
- shuffle_seed (int): can be used to further randomize
299
- load_wav (bool): if False, skip loading the wav but returns a tensor of 0
300
- with the expected segment_duration (which must be provided if load_wav is False).
301
- permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
302
- are False. Will ensure a permutation on files when going through the dataset.
303
- In that case the epoch number must be provided in order for the model
304
- to continue the permutation across epochs. In that case, it is assumed
305
- that `num_samples = total_batch_size * num_updates_per_epoch`, with
306
- `total_batch_size` the overall batch size accounting for all gpus.
307
- """
308
- def __init__(self,
309
- meta: tp.List[AudioMeta],
310
- segment_duration: tp.Optional[float] = None,
311
- shuffle: bool = True,
312
- num_samples: int = 10_000,
313
- sample_rate: int = 48_000,
314
- channels: int = 2,
315
- pad: bool = True,
316
- sample_on_duration: bool = True,
317
- sample_on_weight: bool = True,
318
- min_segment_ratio: float = 1,
319
- max_read_retry: int = 10,
320
- return_info: bool = False,
321
- min_audio_duration: tp.Optional[float] = None,
322
- max_audio_duration: tp.Optional[float] = None,
323
- shuffle_seed: int = 0,
324
- load_wav: bool = True,
325
- permutation_on_files: bool = False,
326
- ):
327
- assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
328
- assert segment_duration is None or segment_duration > 0
329
- assert segment_duration is None or min_segment_ratio >= 0
330
- self.segment_duration = segment_duration
331
- self.min_segment_ratio = min_segment_ratio
332
- self.max_audio_duration = max_audio_duration
333
- self.min_audio_duration = min_audio_duration
334
- if self.min_audio_duration is not None and self.max_audio_duration is not None:
335
- assert self.min_audio_duration <= self.max_audio_duration
336
- self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
337
- assert len(self.meta) # Fail fast if all data has been filtered.
338
- self.total_duration = sum(d.duration for d in self.meta)
339
-
340
- if segment_duration is None:
341
- num_samples = len(self.meta)
342
- self.num_samples = num_samples
343
- self.shuffle = shuffle
344
- self.sample_rate = sample_rate
345
- self.channels = channels
346
- self.pad = pad
347
- self.sample_on_weight = sample_on_weight
348
- self.sample_on_duration = sample_on_duration
349
- self.sampling_probabilities = self._get_sampling_probabilities()
350
- self.max_read_retry = max_read_retry
351
- self.return_info = return_info
352
- self.shuffle_seed = shuffle_seed
353
- self.current_epoch: tp.Optional[int] = None
354
- self.load_wav = load_wav
355
- if not load_wav:
356
- assert segment_duration is not None
357
- self.permutation_on_files = permutation_on_files
358
- if permutation_on_files:
359
- assert not self.sample_on_duration
360
- assert not self.sample_on_weight
361
- assert self.shuffle
362
-
363
- def start_epoch(self, epoch: int):
364
- self.current_epoch = epoch
365
-
366
- def __len__(self):
367
- return self.num_samples
368
-
369
- def _get_sampling_probabilities(self, normalized: bool = True):
370
- """Return the sampling probabilities for each file inside `self.meta`."""
371
- scores: tp.List[float] = []
372
- for file_meta in self.meta:
373
- score = 1.
374
- if self.sample_on_weight and file_meta.weight is not None:
375
- score *= file_meta.weight
376
- if self.sample_on_duration:
377
- score *= file_meta.duration
378
- scores.append(score)
379
- probabilities = torch.tensor(scores)
380
- if normalized:
381
- probabilities /= probabilities.sum()
382
- return probabilities
383
-
384
- @staticmethod
385
- @lru_cache(16)
386
- def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
387
- # Used to keep the most recent files permutation in memory implicitely.
388
- # will work unless someone is using a lot of Datasets in parallel.
389
- rng = torch.Generator()
390
- rng.manual_seed(base_seed + permutation_index)
391
- return torch.randperm(num_files, generator=rng)
392
-
393
- def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
394
- """Sample a given file from `self.meta`. Can be overridden in subclasses.
395
- This is only called if `segment_duration` is not None.
396
-
397
- You must use the provided random number generator `rng` for reproducibility.
398
- You can further make use of the index accessed.
399
- """
400
- if self.permutation_on_files:
401
- assert self.current_epoch is not None
402
- total_index = self.current_epoch * len(self) + index
403
- permutation_index = total_index // len(self.meta)
404
- relative_index = total_index % len(self.meta)
405
- permutation = AudioDataset._get_file_permutation(
406
- len(self.meta), permutation_index, self.shuffle_seed)
407
- file_index = permutation[relative_index]
408
- return self.meta[file_index]
409
-
410
- if not self.sample_on_weight and not self.sample_on_duration:
411
- file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
412
- else:
413
- file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
414
-
415
- return self.meta[file_index]
416
-
417
- def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
418
- # Override this method in subclass if needed.
419
- if self.load_wav:
420
- return audio_read(path, seek_time, duration, pad=False)
421
- else:
422
- assert self.segment_duration is not None
423
- n_frames = int(self.sample_rate * self.segment_duration)
424
- return torch.zeros(self.channels, n_frames), self.sample_rate
425
-
426
- def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
427
- if self.segment_duration is None:
428
- file_meta = self.meta[index]
429
- out, sr = audio_read(file_meta.path)
430
- out = convert_audio(out, sr, self.sample_rate, self.channels)
431
- n_frames = out.shape[-1]
432
- segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
433
- sample_rate=self.sample_rate, channels=out.shape[0])
434
- else:
435
- rng = torch.Generator()
436
- if self.shuffle:
437
- # We use index, plus extra randomness, either totally random if we don't know the epoch.
438
- # otherwise we make use of the epoch number and optional shuffle_seed.
439
- if self.current_epoch is None:
440
- rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
441
- else:
442
- rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
443
- else:
444
- # We only use index
445
- rng.manual_seed(index)
446
-
447
- for retry in range(self.max_read_retry):
448
- file_meta = self.sample_file(index, rng)
449
- # We add some variance in the file position even if audio file is smaller than segment
450
- # without ending up with empty segments
451
-
452
- # sample with phrase
453
- if file_meta.phr_start is not None:
454
- # max_seek = max(0, len(file_meta.phr_start[:-1]))
455
- max_seek = max(0, len([start for start in file_meta.phr_start if start + self.segment_duration <= file_meta.duration])) # sample with time
456
- seek_time = file_meta.phr_start[int(torch.rand(1, generator=rng).item() * max_seek)] # choose from phrase
457
-
458
- else:
459
- max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
460
- seek_time = torch.rand(1, generator=rng).item() * max_seek # can be change to choose phrase start
461
-
462
- if file_meta.duration == self.segment_duration:
463
- seek_time = 0
464
-
465
- # phr_dur = 60./file_meta.bpm * (file_meta.meter * 4.) # if meter=4 then 16 beats per phrase
466
- try:
467
- out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
468
- # out, sr = audio_read(file_meta.path, seek_time, phr_dur, pad=False) # use phrase trunk as input
469
- out = convert_audio(out, sr, self.sample_rate, self.channels)
470
- n_frames = out.shape[-1]
471
- target_frames = int(self.segment_duration * self.sample_rate)
472
- if self.pad:
473
- out = F.pad(out, (0, target_frames - n_frames))
474
- segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
475
- sample_rate=self.sample_rate, channels=out.shape[0])
476
- except Exception as exc:
477
- logger.warning("Error opening file %s: %r", file_meta.path, exc)
478
- if retry == self.max_read_retry - 1:
479
- raise
480
- else:
481
- break
482
-
483
- if self.return_info:
484
- # Returns the wav and additional information on the wave segment
485
- return out, segment_info
486
- else:
487
- return out
488
-
489
- def collater(self, samples):
490
- """The collater function has to be provided to the dataloader
491
- if AudioDataset has return_info=True in order to properly collate
492
- the samples of a batch.
493
- """
494
- if self.segment_duration is None and len(samples) > 1:
495
- assert self.pad, "Must allow padding when batching examples of different durations."
496
-
497
- # In this case the audio reaching the collater is of variable length as segment_duration=None.
498
- to_pad = self.segment_duration is None and self.pad
499
- if to_pad:
500
- max_len = max([wav.shape[-1] for wav, _ in samples])
501
-
502
- def _pad_wav(wav):
503
- return F.pad(wav, (0, max_len - wav.shape[-1]))
504
-
505
- if self.return_info:
506
- if len(samples) > 0:
507
- assert len(samples[0]) == 2
508
- assert isinstance(samples[0][0], torch.Tensor)
509
- assert isinstance(samples[0][1], SegmentInfo)
510
-
511
- wavs = [wav for wav, _ in samples]
512
- segment_infos = [copy.deepcopy(info) for _, info in samples]
513
-
514
- if to_pad:
515
- # Each wav could be of a different duration as they are not segmented.
516
- for i in range(len(samples)):
517
- # Determines the total length of the signal with padding, so we update here as we pad.
518
- segment_infos[i].total_frames = max_len
519
- wavs[i] = _pad_wav(wavs[i])
520
-
521
- wav = torch.stack(wavs)
522
- return wav, segment_infos
523
- else:
524
- assert isinstance(samples[0], torch.Tensor)
525
- if to_pad:
526
- samples = [_pad_wav(s) for s in samples]
527
- return torch.stack(samples)
528
-
529
- def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
530
- """Filters out audio files with audio durations that will not allow to sample examples from them."""
531
- orig_len = len(meta)
532
-
533
- # Filter data that is too short.
534
- if self.min_audio_duration is not None:
535
- meta = [m for m in meta if m.duration >= self.min_audio_duration]
536
-
537
- # Filter data that is too long.
538
- if self.max_audio_duration is not None:
539
- meta = [m for m in meta if m.duration <= self.max_audio_duration]
540
-
541
- filtered_len = len(meta)
542
- removed_percentage = 100*(1-float(filtered_len)/orig_len)
543
- msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
544
- if removed_percentage < 10:
545
- logging.debug(msg)
546
- else:
547
- logging.warning(msg)
548
- return meta
549
-
550
- @classmethod
551
- def from_meta(cls, root: tp.Union[str, Path], **kwargs):
552
- """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
553
-
554
- Args:
555
- root (str or Path): Path to root folder containing audio files.
556
- kwargs: Additional keyword arguments for the AudioDataset.
557
- """
558
- root = Path(root)
559
- if root.is_dir():
560
- if (root / 'data.jsonl').exists():
561
- root = root / 'data.jsonl'
562
- elif (root / 'data.jsonl.gz').exists():
563
- root = root / 'data.jsonl.gz'
564
- else:
565
- raise ValueError("Don't know where to read metadata from in the dir. "
566
- "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
567
- meta = load_audio_meta(root)
568
- return cls(meta, **kwargs)
569
-
570
- @classmethod
571
- def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
572
- exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
573
- """Instantiate AudioDataset from a path containing (possibly nested) audio files.
574
-
575
- Args:
576
- root (str or Path): Path to root folder containing audio files.
577
- minimal_meta (bool): Whether to only load minimal metadata or not.
578
- exts (list of str): Extensions for audio files.
579
- kwargs: Additional keyword arguments for the AudioDataset.
580
- """
581
- root = Path(root)
582
- if root.is_file():
583
- meta = load_audio_meta(root, resolve=True)
584
- else:
585
- meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
586
- return cls(meta, **kwargs)
587
-
588
-
589
- def main():
590
- logging.basicConfig(stream=sys.stderr, level=logging.INFO)
591
- parser = argparse.ArgumentParser(
592
- prog='audio_dataset',
593
- description='Generate .jsonl files by scanning a folder.')
594
- parser.add_argument('root', help='Root folder with all the audio files')
595
- parser.add_argument('output_meta_file',
596
- help='Output file to store the metadata, ')
597
- parser.add_argument('--complete',
598
- action='store_false', dest='minimal', default=True,
599
- help='Retrieve all metadata, even the one that are expansive '
600
- 'to compute (e.g. normalization).')
601
- parser.add_argument('--resolve',
602
- action='store_true', default=False,
603
- help='Resolve the paths to be absolute and with no symlinks.')
604
- parser.add_argument('--workers',
605
- default=10, type=int,
606
- help='Number of workers.')
607
- args = parser.parse_args()
608
- meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
609
- resolve=args.resolve, minimal=args.minimal, workers=args.workers)
610
- save_audio_meta(args.output_meta_file, meta)
611
-
612
-
613
- if __name__ == '__main__':
614
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/audio_utils.py DELETED
@@ -1,385 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Various utilities for audio convertion (pcm format, sample rate and channels),
7
- and volume normalization."""
8
- import sys
9
- import typing as tp
10
-
11
- import julius
12
- import torch
13
- import torchaudio
14
- import numpy as np
15
-
16
- from .chords import Chords
17
- chords = Chords() # initiate object
18
-
19
-
20
- def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
21
- """Convert audio to the given number of channels.
22
-
23
- Args:
24
- wav (torch.Tensor): Audio wave of shape [B, C, T].
25
- channels (int): Expected number of channels as output.
26
- Returns:
27
- torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
28
- """
29
- *shape, src_channels, length = wav.shape
30
- if src_channels == channels:
31
- pass
32
- elif channels == 1:
33
- # Case 1:
34
- # The caller asked 1-channel audio, and the stream has multiple
35
- # channels, downmix all channels.
36
- wav = wav.mean(dim=-2, keepdim=True)
37
- elif src_channels == 1:
38
- # Case 2:
39
- # The caller asked for multiple channels, but the input file has
40
- # a single channel, replicate the audio over all channels.
41
- wav = wav.expand(*shape, channels, length)
42
- elif src_channels >= channels:
43
- # Case 3:
44
- # The caller asked for multiple channels, and the input file has
45
- # more channels than requested. In that case return the first channels.
46
- wav = wav[..., :channels, :]
47
- else:
48
- # Case 4: What is a reasonable choice here?
49
- raise ValueError('The audio file has less channels than requested but is not mono.')
50
- return wav
51
-
52
-
53
- def convert_audio(wav: torch.Tensor, from_rate: float,
54
- to_rate: float, to_channels: int) -> torch.Tensor:
55
- """Convert audio to new sample rate and number of audio channels."""
56
- wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
57
- wav = convert_audio_channels(wav, to_channels)
58
- return wav
59
-
60
-
61
- def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
62
- loudness_compressor: bool = False, energy_floor: float = 2e-3):
63
- """Normalize an input signal to a user loudness in dB LKFS.
64
- Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
65
-
66
- Args:
67
- wav (torch.Tensor): Input multichannel audio data.
68
- sample_rate (int): Sample rate.
69
- loudness_headroom_db (float): Target loudness of the output in dB LUFS.
70
- loudness_compressor (bool): Uses tanh for soft clipping.
71
- energy_floor (float): anything below that RMS level will not be rescaled.
72
- Returns:
73
- torch.Tensor: Loudness normalized output data.
74
- """
75
- energy = wav.pow(2).mean().sqrt().item()
76
- if energy < energy_floor:
77
- return wav
78
- transform = torchaudio.transforms.Loudness(sample_rate)
79
- input_loudness_db = transform(wav).item()
80
- # calculate the gain needed to scale to the desired loudness level
81
- delta_loudness = -loudness_headroom_db - input_loudness_db
82
- gain = 10.0 ** (delta_loudness / 20.0)
83
- output = gain * wav
84
- if loudness_compressor:
85
- output = torch.tanh(output)
86
- assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
87
- return output
88
-
89
-
90
- def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
91
- """Utility function to clip the audio with logging if specified."""
92
- max_scale = wav.abs().max()
93
- if log_clipping and max_scale > 1:
94
- clamp_prob = (wav.abs() > 1).float().mean().item()
95
- print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
96
- clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
97
- wav.clamp_(-1, 1)
98
-
99
-
100
- def normalize_audio(wav: torch.Tensor, normalize: bool = True,
101
- strategy: str = 'peak', peak_clip_headroom_db: float = 1,
102
- rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
103
- loudness_compressor: bool = False, log_clipping: bool = False,
104
- sample_rate: tp.Optional[int] = None,
105
- stem_name: tp.Optional[str] = None) -> torch.Tensor:
106
- """Normalize the audio according to the prescribed strategy (see after).
107
-
108
- Args:
109
- wav (torch.Tensor): Audio data.
110
- normalize (bool): if `True` (default), normalizes according to the prescribed
111
- strategy (see after). If `False`, the strategy is only used in case clipping
112
- would happen.
113
- strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
114
- i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
115
- with extra headroom to avoid clipping. 'clip' just clips.
116
- peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
117
- rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
118
- than the `peak_clip` one to avoid further clipping.
119
- loudness_headroom_db (float): Target loudness for loudness normalization.
120
- loudness_compressor (bool): If True, uses tanh based soft clipping.
121
- log_clipping (bool): If True, basic logging on stderr when clipping still
122
- occurs despite strategy (only for 'rms').
123
- sample_rate (int): Sample rate for the audio data (required for loudness).
124
- stem_name (str, optional): Stem name for clipping logging.
125
- Returns:
126
- torch.Tensor: Normalized audio.
127
- """
128
- scale_peak = 10 ** (-peak_clip_headroom_db / 20)
129
- scale_rms = 10 ** (-rms_headroom_db / 20)
130
- if strategy == 'peak':
131
- rescaling = (scale_peak / wav.abs().max())
132
- if normalize or rescaling < 1:
133
- wav = wav * rescaling
134
- elif strategy == 'clip':
135
- wav = wav.clamp(-scale_peak, scale_peak)
136
- elif strategy == 'rms':
137
- mono = wav.mean(dim=0)
138
- rescaling = scale_rms / mono.pow(2).mean().sqrt()
139
- if normalize or rescaling < 1:
140
- wav = wav * rescaling
141
- _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
142
- elif strategy == 'loudness':
143
- assert sample_rate is not None, "Loudness normalization requires sample rate."
144
- wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
145
- _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
146
- else:
147
- assert wav.abs().max() < 1
148
- assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
149
- return wav
150
-
151
-
152
- def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
153
- """Convert audio to float 32 bits PCM format.
154
- """
155
- if wav.dtype.is_floating_point:
156
- return wav
157
- elif wav.dtype == torch.int16:
158
- return wav.float() / 2**15
159
- elif wav.dtype == torch.int32:
160
- return wav.float() / 2**31
161
- raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
162
-
163
-
164
- def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
165
- """Convert audio to int 16 bits PCM format.
166
-
167
- ..Warning:: There exist many formula for doing this conversion. None are perfect
168
- due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
169
- or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
170
- it is possible that `i16_pcm(f32_pcm)) != Identity`.
171
- """
172
- if wav.dtype.is_floating_point:
173
- assert wav.abs().max() <= 1
174
- candidate = (wav * 2 ** 15).round()
175
- if candidate.max() >= 2 ** 15: # clipping would occur
176
- candidate = (wav * (2 ** 15 - 1)).round()
177
- return candidate.short()
178
- else:
179
- assert wav.dtype == torch.int16
180
- return wav
181
-
182
- def convert_txtchord2chroma_orig(text_chords, bpms, meters, gen_sec):
183
- chromas = []
184
- # total_len = int(gen_sec * 44100 / 512)
185
- total_len = int(gen_sec * 32000 / 640)
186
- for chord, bpm, meter in zip(text_chords, bpms, meters):
187
- phr_len = int(60. / bpm * (meter * 4) * 32000 / 640)
188
- # phr_len = int(60. / bpm * (meter * 4) * 44100 / 2048)
189
- chroma = torch.zeros([total_len, 12])
190
- count = 0
191
- offset = 0
192
-
193
- stext = chord.split(" ")
194
- timebin = phr_len // 4 # frames per bar
195
- while count < total_len:
196
- for tokens in stext:
197
- if count >= total_len:
198
- break
199
- stoken = tokens.split(',')
200
- for token in stoken:
201
- off_timebin = timebin + offset
202
- rounded_timebin = round(off_timebin)
203
- offset = off_timebin - rounded_timebin
204
- offset = offset/len(stoken)
205
- add_step = rounded_timebin//len(stoken)
206
- mhot = chords.chord(token)
207
- rolled = np.roll(mhot[2], mhot[0])
208
- for i in range(count, count + add_step):
209
- if count >= total_len:
210
- break
211
- chroma[i] = torch.Tensor(rolled)
212
- count += 1
213
- chromas.append(chroma)
214
- chroma = torch.stack(chromas)
215
- return chroma
216
-
217
- def convert_txtchord2chroma(chord, bpm, meter, gen_sec):
218
- total_len = int(gen_sec * 32000 / 640)
219
-
220
- phr_len = int(60. / bpm * (meter * 4) * 32000 / 640)
221
- # phr_len = int(60. / bpm * (meter * 4) * 44100 / 2048)
222
- chroma = torch.zeros([total_len, 12])
223
- count = 0
224
- offset = 0
225
-
226
- stext = chord.split(" ")
227
- timebin = phr_len // 4 # frames per bar
228
- while count < total_len:
229
- for tokens in stext:
230
- if count >= total_len:
231
- break
232
- stoken = tokens.split(',')
233
- for token in stoken:
234
- off_timebin = timebin + offset
235
- rounded_timebin = round(off_timebin)
236
- offset = off_timebin - rounded_timebin
237
- offset = offset/len(stoken)
238
- add_step = rounded_timebin//len(stoken)
239
- mhot = chords.chord(token)
240
- rolled = np.roll(mhot[2], mhot[0])
241
- for i in range(count, count + add_step):
242
- if count >= total_len:
243
- break
244
- chroma[i] = torch.Tensor(rolled)
245
- count += 1
246
- return chroma
247
-
248
-
249
-
250
- def convert_txtchord2chroma_24(chord, bpm, meter, gen_sec):
251
- total_len = int(gen_sec * 32000 / 640)
252
-
253
- phr_len = int(60. / bpm * (meter * 4) * 32000 / 640)
254
- # phr_len = int(60. / bpm * (meter * 4) * 44100 / 2048)
255
- chroma = torch.zeros([total_len, 24])
256
- count = 0
257
- offset = 0
258
-
259
- stext = chord.split(" ")
260
- timebin = phr_len // 4 # frames per bar
261
- while count < total_len:
262
- for tokens in stext:
263
- if count >= total_len:
264
- break
265
- stoken = tokens.split(',')
266
- for token in stoken:
267
- off_timebin = timebin + offset
268
- rounded_timebin = round(off_timebin)
269
- offset = off_timebin - rounded_timebin
270
- offset = offset/len(stoken)
271
- add_step = rounded_timebin//len(stoken)
272
-
273
- root, bass, ivs_vec, _ = chords.chord(token)
274
- root_vec = torch.zeros(12)
275
- root_vec[root] = 1
276
- final_vec = np.concatenate([root_vec, ivs_vec]) # [C]
277
- for i in range(count, count + add_step):
278
- if count >= total_len:
279
- break
280
- chroma[i] = torch.Tensor(final_vec)
281
- count += 1
282
- return chroma
283
-
284
- def get_chroma_chord_from_lab(chord_path, gen_sec):
285
- total_len = int(gen_sec * 32000 / 640)
286
- feat_hz = 32000/640
287
- intervals = []
288
- labels = []
289
- feat_chord = np.zeros((12, total_len)) # root| ivs
290
- with open(chord_path, 'r') as f:
291
- for line in f.readlines():
292
- splits = line.split()
293
- if len(splits) == 3:
294
- st_sec, ed_sec, ctag = splits
295
- st_sec = float(st_sec)
296
- ed_sec = float(ed_sec)
297
-
298
- st_frame = int(st_sec*feat_hz)
299
- ed_frame = int(ed_sec*feat_hz)
300
-
301
- mhot = chords.chord(ctag)
302
- final_vec = np.roll(mhot[2], mhot[0])
303
-
304
- final_vec = final_vec[..., None] # [C, T]
305
- feat_chord[:, st_frame:ed_frame] = final_vec
306
- feat_chord = torch.from_numpy(feat_chord)
307
- return feat_chord
308
-
309
-
310
- def get_chroma_chord_from_text(text_chord, bpm, meter, gen_sec):
311
- total_len = int(gen_sec * 32000 / 640)
312
-
313
- phr_len = int(60. / bpm * (meter * 4) * 32000 / 640)
314
- chroma = np.zeros([12, total_len])
315
- count = 0
316
- offset = 0
317
-
318
- stext = chord.split(" ")
319
- timebin = phr_len // 4 # frames per bar
320
- while count < total_len:
321
- for tokens in stext:
322
- if count >= total_len:
323
- break
324
- stoken = tokens.split(',')
325
- for token in stoken:
326
- off_timebin = timebin + offset
327
- rounded_timebin = round(off_timebin)
328
- offset = off_timebin - rounded_timebin
329
- offset = offset/len(stoken)
330
- add_step = rounded_timebin//len(stoken)
331
- mhot = chords.chord(token)
332
- final_vec = np.roll(mhot[2], mhot[0])
333
- final_vec = final_vec[..., None] # [C, T]
334
-
335
- for i in range(count, count + add_step):
336
- if count >= total_len:
337
- break
338
- chroma[:, i] = final_vec
339
- count += 1
340
- feat_chord = torch.from_numpy(feat_chord)
341
- return feat_chord
342
-
343
- def get_beat_from_npy(beat_path, gen_sec):
344
- total_len = int(gen_sec * 32000 / 640)
345
-
346
- beats_np = np.load(beat_path, allow_pickle=True)
347
- feat_beats = np.zeros((2, total_len))
348
- meter = int(max(beats_np.T[1]))
349
- beat_time = beats_np[:, 0]
350
- bar_time = beats_np[np.where(beats_np[:, 1] == 1)[0], 0]
351
-
352
- beat_frame = [int((t)*feat_hz) for t in beat_time if (t >= 0 and t < duration)]
353
- bar_frame =[int((t)*feat_hz) for t in bar_time if (t >= 0 and t < duration)]
354
-
355
- feat_beats[0, beat_frame] = 1
356
- feat_beats[1, bar_frame] = 1
357
- kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05])
358
- feat_beats[0] = np.convolve(feat_beats[0] , kernel, 'same') # apply soft kernel
359
- beat_events = feat_beats[0] + feat_beats[1]
360
- beat_events = torch.tensor(beat_events).unsqueeze(0) # [T] -> [1, T]
361
-
362
- bpm = 60 // np.mean([j-i for i, j in zip(beat_time[:-1], beat_time[1:])])
363
- return beat_events, bpm, meter
364
-
365
- def get_beat_from_bpm(bpm, meter, gen_sec):
366
- total_len = int(gen_sec * 32000 / 640)
367
-
368
- feat_beats = np.zeros((2, total_len))
369
-
370
- beat_time_gap = 60 / bpm
371
- beat_gap = 60 / bpm * feat_hz
372
-
373
- beat_time = np.arange(0, duration, beat_time_gap)
374
- beat_frame = np.round(np.arange(0, n_frames_feat, beat_gap)).astype(int)
375
- if beat_frame[-1] == n_frames_feat:
376
- beat_frame = beat_frame[:-1]
377
- bar_frame = beat_frame[::meter]
378
-
379
- feat_beats[0, beat_frame] = 1
380
- feat_beats[1, bar_frame] = 1
381
- kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05])
382
- feat_beats[0] = np.convolve(feat_beats[0] , kernel, 'same') # apply soft kernel
383
- beat_events = feat_beats[0] + feat_beats[1]
384
- beat_events = torch.tensor(beat_events).unsqueeze(0) # [T] -> [1, T]
385
- return beat_events, beat_time, meter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/btc_chords.py DELETED
@@ -1,524 +0,0 @@
1
- # encoding: utf-8
2
- """
3
- This module contains chord evaluation functionality.
4
-
5
- It provides the evaluation measures used for the MIREX ACE task, and
6
- tries to follow [1]_ and [2]_ as closely as possible.
7
-
8
- Notes
9
- -----
10
- This implementation tries to follow the references and their implementation
11
- (e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there
12
- are some known (and possibly some unknown) differences. If you find one not
13
- listed in the following, please file an issue:
14
-
15
- - Detected chord segments are adjusted to fit the length of the annotations.
16
- In particular, this means that, if necessary, filler segments of 'no chord'
17
- are added at beginnings and ends. This can result in different segmentation
18
- scores compared to the original implementation.
19
-
20
- References
21
- ----------
22
- .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information
23
- from Music Signals." Dissertation,
24
- Department for Electronic Engineering, Queen Mary University of London,
25
- 2010.
26
- .. [2] Johan Pauwels and Geoffroy Peeters.
27
- "Evaluating Automatically Estimated Chord Sequences."
28
- In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
29
-
30
- """
31
-
32
- import numpy as np
33
- import pandas as pd
34
-
35
-
36
- CHORD_DTYPE = [('root', np.int_),
37
- ('bass', np.int_),
38
- ('intervals', np.int_, (12,)),
39
- ('is_major',np.bool_)]
40
-
41
- CHORD_ANN_DTYPE = [('start', np.float32),
42
- ('end', np.float32),
43
- ('chord', CHORD_DTYPE)]
44
-
45
- NO_CHORD = (-1, -1, np.zeros(12, dtype=np.int_), False)
46
- UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int_) * -1, False)
47
-
48
- PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
49
-
50
-
51
- def idx_to_chord(idx):
52
- if idx == 24:
53
- return "-"
54
- elif idx == 25:
55
- return u"\u03B5"
56
-
57
- minmaj = idx % 2
58
- root = idx // 2
59
-
60
- return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m")
61
-
62
- class Chords:
63
-
64
- def __init__(self):
65
- self._shorthands = {
66
- 'maj': self.interval_list('(1,3,5)'),
67
- 'min': self.interval_list('(1,b3,5)'),
68
- 'dim': self.interval_list('(1,b3,b5)'),
69
- 'aug': self.interval_list('(1,3,#5)'),
70
- 'maj7': self.interval_list('(1,3,5,7)'),
71
- 'min7': self.interval_list('(1,b3,5,b7)'),
72
- '7': self.interval_list('(1,3,5,b7)'),
73
- '6': self.interval_list('(1,6)'), # custom
74
- '5': self.interval_list('(1,5)'),
75
- '4': self.interval_list('(1,4)'), # custom
76
- '1': self.interval_list('(1)'),
77
- 'dim7': self.interval_list('(1,b3,b5,bb7)'),
78
- 'hdim7': self.interval_list('(1,b3,b5,b7)'),
79
- 'minmaj7': self.interval_list('(1,b3,5,7)'),
80
- 'maj6': self.interval_list('(1,3,5,6)'),
81
- 'min6': self.interval_list('(1,b3,5,6)'),
82
- '9': self.interval_list('(1,3,5,b7,9)'),
83
- 'maj9': self.interval_list('(1,3,5,7,9)'),
84
- 'min9': self.interval_list('(1,b3,5,b7,9)'),
85
- 'add9': self.interval_list('(1,3,5,9)'), # custom
86
- 'sus2': self.interval_list('(1,2,5)'),
87
- 'sus4': self.interval_list('(1,4,5)'),
88
- '7sus2': self.interval_list('(1,2,5,b7)'), # custom
89
- '7sus4': self.interval_list('(1,4,5,b7)'), # custom
90
- '11': self.interval_list('(1,3,5,b7,9,11)'),
91
- 'min11': self.interval_list('(1,b3,5,b7,9,11)'),
92
- '13': self.interval_list('(1,3,5,b7,13)'),
93
- 'maj13': self.interval_list('(1,3,5,7,13)'),
94
- 'min13': self.interval_list('(1,b3,5,b7,13)')
95
- }
96
-
97
- def chords(self, labels):
98
-
99
- """
100
- Transform a list of chord labels into an array of internal numeric
101
- representations.
102
-
103
- Parameters
104
- ----------
105
- labels : list
106
- List of chord labels (str).
107
-
108
- Returns
109
- -------
110
- chords : numpy.array
111
- Structured array with columns 'root', 'bass', and 'intervals',
112
- containing a numeric representation of chords.
113
-
114
- """
115
- crds = np.zeros(len(labels), dtype=CHORD_DTYPE)
116
- cache = {}
117
- for i, lbl in enumerate(labels):
118
- cv = cache.get(lbl, None)
119
- if cv is None:
120
- cv = self.chord(lbl)
121
- cache[lbl] = cv
122
- crds[i] = cv
123
-
124
- return crds
125
-
126
- def label_error_modify(self, label):
127
- if label == 'Emin/4': label = 'E:min/4'
128
- elif label == 'A7/3': label = 'A:7/3'
129
- elif label == 'Bb7/3': label = 'Bb:7/3'
130
- elif label == 'Bb7/5': label = 'Bb:7/5'
131
- elif label.find(':') == -1:
132
- if label.find('min') != -1:
133
- label = label[:label.find('min')] + ':' + label[label.find('min'):]
134
- return label
135
-
136
- def chord(self, label):
137
- """
138
- Transform a chord label into the internal numeric represenation of
139
- (root, bass, intervals array).
140
-
141
- Parameters
142
- ----------
143
- label : str
144
- Chord label.
145
-
146
- Returns
147
- -------
148
- chord : tuple
149
- Numeric representation of the chord: (root, bass, intervals array).
150
-
151
- """
152
-
153
-
154
- is_major = False
155
-
156
- if label == 'N':
157
- return NO_CHORD
158
- if label == 'X':
159
- return UNKNOWN_CHORD
160
-
161
- label = self.label_error_modify(label)
162
-
163
- c_idx = label.find(':')
164
- s_idx = label.find('/')
165
-
166
- if c_idx == -1:
167
- quality_str = 'maj'
168
- if s_idx == -1:
169
- root_str = label
170
- bass_str = ''
171
- else:
172
- root_str = label[:s_idx]
173
- bass_str = label[s_idx + 1:]
174
- else:
175
- root_str = label[:c_idx]
176
- if s_idx == -1:
177
- quality_str = label[c_idx + 1:]
178
- bass_str = ''
179
- else:
180
- quality_str = label[c_idx + 1:s_idx]
181
- bass_str = label[s_idx + 1:]
182
-
183
- root = self.pitch(root_str)
184
- bass = self.interval(bass_str) if bass_str else 0
185
- ivs = self.chord_intervals(quality_str)
186
- ivs[bass] = 1
187
-
188
- if 'min' in quality_str:
189
- is_major = False
190
- else:
191
- is_major = True
192
-
193
-
194
- return root, bass, ivs, is_major
195
-
196
- _l = [0, 1, 1, 0, 1, 1, 1]
197
- _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1
198
-
199
- def modify(self, base_pitch, modifier):
200
- """
201
- Modify a pitch class in integer representation by a given modifier string.
202
-
203
- A modifier string can be any sequence of 'b' (one semitone down)
204
- and '#' (one semitone up).
205
-
206
- Parameters
207
- ----------
208
- base_pitch : int
209
- Pitch class as integer.
210
- modifier : str
211
- String of modifiers ('b' or '#').
212
-
213
- Returns
214
- -------
215
- modified_pitch : int
216
- Modified root note.
217
-
218
- """
219
- for m in modifier:
220
- if m == 'b':
221
- base_pitch -= 1
222
- elif m == '#':
223
- base_pitch += 1
224
- else:
225
- raise ValueError('Unknown modifier: {}'.format(m))
226
- return base_pitch
227
-
228
- def pitch(self, pitch_str):
229
- """
230
- Convert a string representation of a pitch class (consisting of root
231
- note and modifiers) to an integer representation.
232
-
233
- Parameters
234
- ----------
235
- pitch_str : str
236
- String representation of a pitch class.
237
-
238
- Returns
239
- -------
240
- pitch : int
241
- Integer representation of a pitch class.
242
-
243
- """
244
- return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7],
245
- pitch_str[1:]) % 12
246
-
247
- def interval(self, interval_str):
248
- """
249
- Convert a string representation of a musical interval into a pitch class
250
- (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its
251
- base note).
252
-
253
- Parameters
254
- ----------
255
- interval_str : str
256
- Musical interval.
257
-
258
- Returns
259
- -------
260
- pitch_class : int
261
- Number of semitones to base note of interval.
262
-
263
- """
264
- for i, c in enumerate(interval_str):
265
- if c.isdigit():
266
- return self.modify(self._chroma_id[int(interval_str[i:]) - 1],
267
- interval_str[:i]) % 12
268
-
269
- def interval_list(self, intervals_str, given_pitch_classes=None):
270
- """
271
- Convert a list of intervals given as string to a binary pitch class
272
- representation. For example, 'b3, 5' would become
273
- [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0].
274
-
275
- Parameters
276
- ----------
277
- intervals_str : str
278
- List of intervals as comma-separated string (e.g. 'b3, 5').
279
- given_pitch_classes : None or numpy array
280
- If None, start with empty pitch class array, if numpy array of length
281
- 12, this array will be modified.
282
-
283
- Returns
284
- -------
285
- pitch_classes : numpy array
286
- Binary pitch class representation of intervals.
287
-
288
- """
289
- if given_pitch_classes is None:
290
- given_pitch_classes = np.zeros(12, dtype=np.int_)
291
- for int_def in intervals_str[1:-1].split(','):
292
- int_def = int_def.strip()
293
- if int_def[0] == '*':
294
- given_pitch_classes[self.interval(int_def[1:])] = 0
295
- else:
296
- given_pitch_classes[self.interval(int_def)] = 1
297
- return given_pitch_classes
298
-
299
- # mapping of shorthand interval notations to the actual interval representation
300
-
301
- def chord_intervals(self, quality_str):
302
- """
303
- Convert a chord quality string to a pitch class representation. For
304
- example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0].
305
-
306
- Parameters
307
- ----------
308
- quality_str : str
309
- String defining the chord quality.
310
-
311
- Returns
312
- -------
313
- pitch_classes : numpy array
314
- Binary pitch class representation of chord quality.
315
-
316
- """
317
- list_idx = quality_str.find('(')
318
- if list_idx == -1:
319
- return self._shorthands[quality_str].copy()
320
- if list_idx != 0:
321
- ivs = self._shorthands[quality_str[:list_idx]].copy()
322
- else:
323
- ivs = np.zeros(12, dtype=np.int_)
324
-
325
-
326
- return self.interval_list(quality_str[list_idx:], ivs)
327
-
328
- def load_chords(self, filename):
329
- """
330
- Load chords from a text file.
331
-
332
- The chord must follow the syntax defined in [1]_.
333
-
334
- Parameters
335
- ----------
336
- filename : str
337
- File containing chord segments.
338
-
339
- Returns
340
- -------
341
- crds : numpy structured array
342
- Structured array with columns "start", "end", and "chord",
343
- containing the beginning, end, and chord definition of chord
344
- segments.
345
-
346
- References
347
- ----------
348
- .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony
349
- Information from Music Signals." Dissertation,
350
- Department for Electronic Engineering, Queen Mary University of
351
- London, 2010.
352
-
353
- """
354
- start, end, chord_labels = [], [], []
355
- with open(filename, 'r') as f:
356
- for line in f:
357
- if line:
358
-
359
- splits = line.split()
360
- if len(splits) == 3:
361
-
362
- s = splits[0]
363
- e = splits[1]
364
- l = splits[2]
365
-
366
- start.append(float(s))
367
- end.append(float(e))
368
- chord_labels.append(l)
369
-
370
- crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)
371
- crds['start'] = start
372
- crds['end'] = end
373
- crds['chord'] = self.chords(chord_labels)
374
-
375
- return crds
376
-
377
- def reduce_to_triads(self, chords, keep_bass=False):
378
- """
379
- Reduce chords to triads.
380
-
381
- The function follows the reduction rules implemented in [1]_. If a chord
382
- chord does not contain a third, major second or fourth, it is reduced to
383
- a power chord. If it does not contain neither a third nor a fifth, it is
384
- reduced to a single note "chord".
385
-
386
- Parameters
387
- ----------
388
- chords : numpy structured array
389
- Chords to be reduced.
390
- keep_bass : bool
391
- Indicates whether to keep the bass note or set it to 0.
392
-
393
- Returns
394
- -------
395
- reduced_chords : numpy structured array
396
- Chords reduced to triads.
397
-
398
- References
399
- ----------
400
- .. [1] Johan Pauwels and Geoffroy Peeters.
401
- "Evaluating Automatically Estimated Chord Sequences."
402
- In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
403
-
404
- """
405
- unison = chords['intervals'][:, 0].astype(bool)
406
- maj_sec = chords['intervals'][:, 2].astype(bool)
407
- min_third = chords['intervals'][:, 3].astype(bool)
408
- maj_third = chords['intervals'][:, 4].astype(bool)
409
- perf_fourth = chords['intervals'][:, 5].astype(bool)
410
- dim_fifth = chords['intervals'][:, 6].astype(bool)
411
- perf_fifth = chords['intervals'][:, 7].astype(bool)
412
- aug_fifth = chords['intervals'][:, 8].astype(bool)
413
- no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1)
414
-
415
- reduced_chords = chords.copy()
416
- ivs = reduced_chords['intervals']
417
-
418
- ivs[~no_chord] = self.interval_list('(1)')
419
- ivs[unison & perf_fifth] = self.interval_list('(1,5)')
420
- ivs[~perf_fourth & maj_sec] = self._shorthands['sus2']
421
- ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4']
422
-
423
- ivs[min_third] = self._shorthands['min']
424
- ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)')
425
- ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim']
426
-
427
- ivs[maj_third] = self._shorthands['maj']
428
- ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)')
429
- ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug']
430
-
431
- if not keep_bass:
432
- reduced_chords['bass'] = 0
433
- else:
434
- # remove bass notes if they are not part of the intervals anymore
435
- reduced_chords['bass'] *= ivs[range(len(reduced_chords)),
436
- reduced_chords['bass']]
437
- # keep -1 in bass for no chords
438
- reduced_chords['bass'][no_chord] = -1
439
-
440
- return reduced_chords
441
-
442
- def convert_to_id(self, root, is_major):
443
- if root == -1:
444
- return 24
445
- else:
446
- if is_major:
447
- return root * 2
448
- else:
449
- return root * 2 + 1
450
-
451
- def get_converted_chord(self, filename):
452
- loaded_chord = self.load_chords(filename)
453
- triads = self.reduce_to_triads(loaded_chord['chord'])
454
-
455
- df = self.assign_chord_id(triads)
456
- df['start'] = loaded_chord['start']
457
- df['end'] = loaded_chord['end']
458
-
459
- return df
460
-
461
- def assign_chord_id(self, entry):
462
- # maj, min chord only
463
- # if you want to add other chord, change this part and get_converted_chord(reduce_to_triads)
464
- df = pd.DataFrame(data=entry[['root', 'is_major']])
465
- df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1)
466
- return df
467
-
468
- def convert_to_id_voca(self, root, quality):
469
- if root == -1:
470
- return 169
471
- else:
472
- if quality == 'min':
473
- return root * 14
474
- elif quality == 'maj':
475
- return root * 14 + 1
476
- elif quality == 'dim':
477
- return root * 14 + 2
478
- elif quality == 'aug':
479
- return root * 14 + 3
480
- elif quality == 'min6':
481
- return root * 14 + 4
482
- elif quality == 'maj6':
483
- return root * 14 + 5
484
- elif quality == 'min7':
485
- return root * 14 + 6
486
- elif quality == 'minmaj7':
487
- return root * 14 + 7
488
- elif quality == 'maj7':
489
- return root * 14 + 8
490
- elif quality == '7':
491
- return root * 14 + 9
492
- elif quality == 'dim7':
493
- return root * 14 + 10
494
- elif quality == 'hdim7':
495
- return root * 14 + 11
496
- elif quality == 'sus2':
497
- return root * 14 + 12
498
- elif quality == 'sus4':
499
- return root * 14 + 13
500
- else:
501
- return 168
502
-
503
-
504
- def lab_file_error_modify(self, ref_labels):
505
- for i in range(len(ref_labels)):
506
- if ref_labels[i][-2:] == ':4':
507
- ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
508
- elif ref_labels[i][-2:] == ':6':
509
- ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
510
- elif ref_labels[i][-4:] == ':6/2':
511
- ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
512
- elif ref_labels[i] == 'Emin/4':
513
- ref_labels[i] = 'E:min/4'
514
- elif ref_labels[i] == 'A7/3':
515
- ref_labels[i] = 'A:7/3'
516
- elif ref_labels[i] == 'Bb7/3':
517
- ref_labels[i] = 'Bb:7/3'
518
- elif ref_labels[i] == 'Bb7/5':
519
- ref_labels[i] = 'Bb:7/5'
520
- elif ref_labels[i].find(':') == -1:
521
- if ref_labels[i].find('min') != -1:
522
- ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
523
- return ref_labels
524
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/chords.py DELETED
@@ -1,524 +0,0 @@
1
- # encoding: utf-8
2
- """
3
- This module contains chord evaluation functionality.
4
-
5
- It provides the evaluation measures used for the MIREX ACE task, and
6
- tries to follow [1]_ and [2]_ as closely as possible.
7
-
8
- Notes
9
- -----
10
- This implementation tries to follow the references and their implementation
11
- (e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there
12
- are some known (and possibly some unknown) differences. If you find one not
13
- listed in the following, please file an issue:
14
-
15
- - Detected chord segments are adjusted to fit the length of the annotations.
16
- In particular, this means that, if necessary, filler segments of 'no chord'
17
- are added at beginnings and ends. This can result in different segmentation
18
- scores compared to the original implementation.
19
-
20
- References
21
- ----------
22
- .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information
23
- from Music Signals." Dissertation,
24
- Department for Electronic Engineering, Queen Mary University of London,
25
- 2010.
26
- .. [2] Johan Pauwels and Geoffroy Peeters.
27
- "Evaluating Automatically Estimated Chord Sequences."
28
- In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
29
-
30
- """
31
-
32
- import numpy as np
33
- import pandas as pd
34
-
35
-
36
- CHORD_DTYPE = [('root', np.int_),
37
- ('bass', np.int_),
38
- ('intervals', np.int_, (12,)),
39
- ('is_major',np.bool_)]
40
-
41
- CHORD_ANN_DTYPE = [('start', np.float32),
42
- ('end', np.float32),
43
- ('chord', CHORD_DTYPE)]
44
-
45
- NO_CHORD = (-1, -1, np.zeros(12, dtype=np.int_), False)
46
- UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int_) * -1, False)
47
-
48
- PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
49
-
50
-
51
- def idx_to_chord(idx):
52
- if idx == 24:
53
- return "-"
54
- elif idx == 25:
55
- return u"\u03B5"
56
-
57
- minmaj = idx % 2
58
- root = idx // 2
59
-
60
- return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m")
61
-
62
- class Chords:
63
-
64
- def __init__(self):
65
- self._shorthands = {
66
- 'maj': self.interval_list('(1,3,5)'),
67
- 'min': self.interval_list('(1,b3,5)'),
68
- 'dim': self.interval_list('(1,b3,b5)'),
69
- 'aug': self.interval_list('(1,3,#5)'),
70
- 'maj7': self.interval_list('(1,3,5,7)'),
71
- 'min7': self.interval_list('(1,b3,5,b7)'),
72
- '7': self.interval_list('(1,3,5,b7)'),
73
- '6': self.interval_list('(1,6)'), # custom
74
- '5': self.interval_list('(1,5)'),
75
- '4': self.interval_list('(1,4)'), # custom
76
- '1': self.interval_list('(1)'),
77
- 'dim7': self.interval_list('(1,b3,b5,bb7)'),
78
- 'hdim7': self.interval_list('(1,b3,b5,b7)'),
79
- 'minmaj7': self.interval_list('(1,b3,5,7)'),
80
- 'maj6': self.interval_list('(1,3,5,6)'),
81
- 'min6': self.interval_list('(1,b3,5,6)'),
82
- '9': self.interval_list('(1,3,5,b7,9)'),
83
- 'maj9': self.interval_list('(1,3,5,7,9)'),
84
- 'min9': self.interval_list('(1,b3,5,b7,9)'),
85
- 'add9': self.interval_list('(1,3,5,9)'), # custom
86
- 'sus2': self.interval_list('(1,2,5)'),
87
- 'sus4': self.interval_list('(1,4,5)'),
88
- '7sus2': self.interval_list('(1,2,5,b7)'), # custom
89
- '7sus4': self.interval_list('(1,4,5,b7)'), # custom
90
- '11': self.interval_list('(1,3,5,b7,9,11)'),
91
- 'min11': self.interval_list('(1,b3,5,b7,9,11)'),
92
- '13': self.interval_list('(1,3,5,b7,13)'),
93
- 'maj13': self.interval_list('(1,3,5,7,13)'),
94
- 'min13': self.interval_list('(1,b3,5,b7,13)')
95
- }
96
-
97
- def chords(self, labels):
98
-
99
- """
100
- Transform a list of chord labels into an array of internal numeric
101
- representations.
102
-
103
- Parameters
104
- ----------
105
- labels : list
106
- List of chord labels (str).
107
-
108
- Returns
109
- -------
110
- chords : numpy.array
111
- Structured array with columns 'root', 'bass', and 'intervals',
112
- containing a numeric representation of chords.
113
-
114
- """
115
- crds = np.zeros(len(labels), dtype=CHORD_DTYPE)
116
- cache = {}
117
- for i, lbl in enumerate(labels):
118
- cv = cache.get(lbl, None)
119
- if cv is None:
120
- cv = self.chord(lbl)
121
- cache[lbl] = cv
122
- crds[i] = cv
123
-
124
- return crds
125
-
126
- def label_error_modify(self, label):
127
- if label == 'Emin/4': label = 'E:min/4'
128
- elif label == 'A7/3': label = 'A:7/3'
129
- elif label == 'Bb7/3': label = 'Bb:7/3'
130
- elif label == 'Bb7/5': label = 'Bb:7/5'
131
- elif label.find(':') == -1:
132
- if label.find('min') != -1:
133
- label = label[:label.find('min')] + ':' + label[label.find('min'):]
134
- return label
135
-
136
- def chord(self, label):
137
- """
138
- Transform a chord label into the internal numeric represenation of
139
- (root, bass, intervals array).
140
-
141
- Parameters
142
- ----------
143
- label : str
144
- Chord label.
145
-
146
- Returns
147
- -------
148
- chord : tuple
149
- Numeric representation of the chord: (root, bass, intervals array).
150
-
151
- """
152
-
153
-
154
- is_major = False
155
-
156
- if label == 'N':
157
- return NO_CHORD
158
- if label == 'X':
159
- return UNKNOWN_CHORD
160
-
161
- label = self.label_error_modify(label)
162
-
163
- c_idx = label.find(':')
164
- s_idx = label.find('/')
165
-
166
- if c_idx == -1:
167
- quality_str = 'maj'
168
- if s_idx == -1:
169
- root_str = label
170
- bass_str = ''
171
- else:
172
- root_str = label[:s_idx]
173
- bass_str = label[s_idx + 1:]
174
- else:
175
- root_str = label[:c_idx]
176
- if s_idx == -1:
177
- quality_str = label[c_idx + 1:]
178
- bass_str = ''
179
- else:
180
- quality_str = label[c_idx + 1:s_idx]
181
- bass_str = label[s_idx + 1:]
182
-
183
- root = self.pitch(root_str)
184
- bass = self.interval(bass_str) if bass_str else 0
185
- ivs = self.chord_intervals(quality_str)
186
- ivs[bass] = 1
187
-
188
- if 'min' in quality_str:
189
- is_major = False
190
- else:
191
- is_major = True
192
-
193
-
194
- return root, bass, ivs, is_major
195
-
196
- _l = [0, 1, 1, 0, 1, 1, 1]
197
- _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1
198
-
199
- def modify(self, base_pitch, modifier):
200
- """
201
- Modify a pitch class in integer representation by a given modifier string.
202
-
203
- A modifier string can be any sequence of 'b' (one semitone down)
204
- and '#' (one semitone up).
205
-
206
- Parameters
207
- ----------
208
- base_pitch : int
209
- Pitch class as integer.
210
- modifier : str
211
- String of modifiers ('b' or '#').
212
-
213
- Returns
214
- -------
215
- modified_pitch : int
216
- Modified root note.
217
-
218
- """
219
- for m in modifier:
220
- if m == 'b':
221
- base_pitch -= 1
222
- elif m == '#':
223
- base_pitch += 1
224
- else:
225
- raise ValueError('Unknown modifier: {}'.format(m))
226
- return base_pitch
227
-
228
- def pitch(self, pitch_str):
229
- """
230
- Convert a string representation of a pitch class (consisting of root
231
- note and modifiers) to an integer representation.
232
-
233
- Parameters
234
- ----------
235
- pitch_str : str
236
- String representation of a pitch class.
237
-
238
- Returns
239
- -------
240
- pitch : int
241
- Integer representation of a pitch class.
242
-
243
- """
244
- return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7],
245
- pitch_str[1:]) % 12
246
-
247
- def interval(self, interval_str):
248
- """
249
- Convert a string representation of a musical interval into a pitch class
250
- (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its
251
- base note).
252
-
253
- Parameters
254
- ----------
255
- interval_str : str
256
- Musical interval.
257
-
258
- Returns
259
- -------
260
- pitch_class : int
261
- Number of semitones to base note of interval.
262
-
263
- """
264
- for i, c in enumerate(interval_str):
265
- if c.isdigit():
266
- return self.modify(self._chroma_id[int(interval_str[i:]) - 1],
267
- interval_str[:i]) % 12
268
-
269
- def interval_list(self, intervals_str, given_pitch_classes=None):
270
- """
271
- Convert a list of intervals given as string to a binary pitch class
272
- representation. For example, 'b3, 5' would become
273
- [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0].
274
-
275
- Parameters
276
- ----------
277
- intervals_str : str
278
- List of intervals as comma-separated string (e.g. 'b3, 5').
279
- given_pitch_classes : None or numpy array
280
- If None, start with empty pitch class array, if numpy array of length
281
- 12, this array will be modified.
282
-
283
- Returns
284
- -------
285
- pitch_classes : numpy array
286
- Binary pitch class representation of intervals.
287
-
288
- """
289
- if given_pitch_classes is None:
290
- given_pitch_classes = np.zeros(12, dtype=np.int_)
291
- for int_def in intervals_str[1:-1].split(','):
292
- int_def = int_def.strip()
293
- if int_def[0] == '*':
294
- given_pitch_classes[self.interval(int_def[1:])] = 0
295
- else:
296
- given_pitch_classes[self.interval(int_def)] = 1
297
- return given_pitch_classes
298
-
299
- # mapping of shorthand interval notations to the actual interval representation
300
-
301
- def chord_intervals(self, quality_str):
302
- """
303
- Convert a chord quality string to a pitch class representation. For
304
- example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0].
305
-
306
- Parameters
307
- ----------
308
- quality_str : str
309
- String defining the chord quality.
310
-
311
- Returns
312
- -------
313
- pitch_classes : numpy array
314
- Binary pitch class representation of chord quality.
315
-
316
- """
317
- list_idx = quality_str.find('(')
318
- if list_idx == -1:
319
- return self._shorthands[quality_str].copy()
320
- if list_idx != 0:
321
- ivs = self._shorthands[quality_str[:list_idx]].copy()
322
- else:
323
- ivs = np.zeros(12, dtype=np.int_)
324
-
325
-
326
- return self.interval_list(quality_str[list_idx:], ivs)
327
-
328
- def load_chords(self, filename):
329
- """
330
- Load chords from a text file.
331
-
332
- The chord must follow the syntax defined in [1]_.
333
-
334
- Parameters
335
- ----------
336
- filename : str
337
- File containing chord segments.
338
-
339
- Returns
340
- -------
341
- crds : numpy structured array
342
- Structured array with columns "start", "end", and "chord",
343
- containing the beginning, end, and chord definition of chord
344
- segments.
345
-
346
- References
347
- ----------
348
- .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony
349
- Information from Music Signals." Dissertation,
350
- Department for Electronic Engineering, Queen Mary University of
351
- London, 2010.
352
-
353
- """
354
- start, end, chord_labels = [], [], []
355
- with open(filename, 'r') as f:
356
- for line in f:
357
- if line:
358
-
359
- splits = line.split()
360
- if len(splits) == 3:
361
-
362
- s = splits[0]
363
- e = splits[1]
364
- l = splits[2]
365
-
366
- start.append(float(s))
367
- end.append(float(e))
368
- chord_labels.append(l)
369
-
370
- crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)
371
- crds['start'] = start
372
- crds['end'] = end
373
- crds['chord'] = self.chords(chord_labels)
374
-
375
- return crds
376
-
377
- def reduce_to_triads(self, chords, keep_bass=False):
378
- """
379
- Reduce chords to triads.
380
-
381
- The function follows the reduction rules implemented in [1]_. If a chord
382
- chord does not contain a third, major second or fourth, it is reduced to
383
- a power chord. If it does not contain neither a third nor a fifth, it is
384
- reduced to a single note "chord".
385
-
386
- Parameters
387
- ----------
388
- chords : numpy structured array
389
- Chords to be reduced.
390
- keep_bass : bool
391
- Indicates whether to keep the bass note or set it to 0.
392
-
393
- Returns
394
- -------
395
- reduced_chords : numpy structured array
396
- Chords reduced to triads.
397
-
398
- References
399
- ----------
400
- .. [1] Johan Pauwels and Geoffroy Peeters.
401
- "Evaluating Automatically Estimated Chord Sequences."
402
- In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
403
-
404
- """
405
- unison = chords['intervals'][:, 0].astype(bool)
406
- maj_sec = chords['intervals'][:, 2].astype(bool)
407
- min_third = chords['intervals'][:, 3].astype(bool)
408
- maj_third = chords['intervals'][:, 4].astype(bool)
409
- perf_fourth = chords['intervals'][:, 5].astype(bool)
410
- dim_fifth = chords['intervals'][:, 6].astype(bool)
411
- perf_fifth = chords['intervals'][:, 7].astype(bool)
412
- aug_fifth = chords['intervals'][:, 8].astype(bool)
413
- no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1)
414
-
415
- reduced_chords = chords.copy()
416
- ivs = reduced_chords['intervals']
417
-
418
- ivs[~no_chord] = self.interval_list('(1)')
419
- ivs[unison & perf_fifth] = self.interval_list('(1,5)')
420
- ivs[~perf_fourth & maj_sec] = self._shorthands['sus2']
421
- ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4']
422
-
423
- ivs[min_third] = self._shorthands['min']
424
- ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)')
425
- ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim']
426
-
427
- ivs[maj_third] = self._shorthands['maj']
428
- ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)')
429
- ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug']
430
-
431
- if not keep_bass:
432
- reduced_chords['bass'] = 0
433
- else:
434
- # remove bass notes if they are not part of the intervals anymore
435
- reduced_chords['bass'] *= ivs[range(len(reduced_chords)),
436
- reduced_chords['bass']]
437
- # keep -1 in bass for no chords
438
- reduced_chords['bass'][no_chord] = -1
439
-
440
- return reduced_chords
441
-
442
- def convert_to_id(self, root, is_major):
443
- if root == -1:
444
- return 24
445
- else:
446
- if is_major:
447
- return root * 2
448
- else:
449
- return root * 2 + 1
450
-
451
- def get_converted_chord(self, filename):
452
- loaded_chord = self.load_chords(filename)
453
- triads = self.reduce_to_triads(loaded_chord['chord'])
454
-
455
- df = self.assign_chord_id(triads)
456
- df['start'] = loaded_chord['start']
457
- df['end'] = loaded_chord['end']
458
-
459
- return df
460
-
461
- def assign_chord_id(self, entry):
462
- # maj, min chord only
463
- # if you want to add other chord, change this part and get_converted_chord(reduce_to_triads)
464
- df = pd.DataFrame(data=entry[['root', 'is_major']])
465
- df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1)
466
- return df
467
-
468
- def convert_to_id_voca(self, root, quality):
469
- if root == -1:
470
- return 169
471
- else:
472
- if quality == 'min':
473
- return root * 14
474
- elif quality == 'maj':
475
- return root * 14 + 1
476
- elif quality == 'dim':
477
- return root * 14 + 2
478
- elif quality == 'aug':
479
- return root * 14 + 3
480
- elif quality == 'min6':
481
- return root * 14 + 4
482
- elif quality == 'maj6':
483
- return root * 14 + 5
484
- elif quality == 'min7':
485
- return root * 14 + 6
486
- elif quality == 'minmaj7':
487
- return root * 14 + 7
488
- elif quality == 'maj7':
489
- return root * 14 + 8
490
- elif quality == '7':
491
- return root * 14 + 9
492
- elif quality == 'dim7':
493
- return root * 14 + 10
494
- elif quality == 'hdim7':
495
- return root * 14 + 11
496
- elif quality == 'sus2':
497
- return root * 14 + 12
498
- elif quality == 'sus4':
499
- return root * 14 + 13
500
- else:
501
- return 168
502
-
503
-
504
- def lab_file_error_modify(self, ref_labels):
505
- for i in range(len(ref_labels)):
506
- if ref_labels[i][-2:] == ':4':
507
- ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
508
- elif ref_labels[i][-2:] == ':6':
509
- ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
510
- elif ref_labels[i][-4:] == ':6/2':
511
- ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
512
- elif ref_labels[i] == 'Emin/4':
513
- ref_labels[i] = 'E:min/4'
514
- elif ref_labels[i] == 'A7/3':
515
- ref_labels[i] = 'A:7/3'
516
- elif ref_labels[i] == 'Bb7/3':
517
- ref_labels[i] = 'Bb:7/3'
518
- elif ref_labels[i] == 'Bb7/5':
519
- ref_labels[i] = 'Bb:7/5'
520
- elif ref_labels[i].find(':') == -1:
521
- if ref_labels[i].find('min') != -1:
522
- ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
523
- return ref_labels
524
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/info_audio_dataset.py DELETED
@@ -1,110 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Base classes for the datasets that also provide non-audio metadata,
7
- e.g. description, text transcription etc.
8
- """
9
- from dataclasses import dataclass
10
- import logging
11
- import math
12
- import re
13
- import typing as tp
14
-
15
- import torch
16
-
17
- from .audio_dataset import AudioDataset, AudioMeta
18
- from ..environment import AudioCraftEnvironment
19
- from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
-
21
-
22
- logger = logging.getLogger(__name__)
23
-
24
-
25
- def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
- """Monkey-patch meta to match cluster specificities."""
27
- meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
- if meta.info_path is not None:
29
- meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
- return meta
31
-
32
-
33
- def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
- """Monkey-patch all meta to match cluster specificities."""
35
- return [_clusterify_meta(m) for m in meta]
36
-
37
-
38
- @dataclass
39
- class AudioInfo(SegmentWithAttributes):
40
- """Dummy SegmentInfo with empty attributes.
41
-
42
- The InfoAudioDataset is expected to return metadata that inherits
43
- from SegmentWithAttributes class and can return conditioning attributes.
44
-
45
- This basically guarantees all datasets will be compatible with current
46
- solver that contain conditioners requiring this.
47
- """
48
- audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
-
50
- def to_condition_attributes(self) -> ConditioningAttributes:
51
- return ConditioningAttributes()
52
-
53
-
54
- class InfoAudioDataset(AudioDataset):
55
- """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
-
57
- See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
- """
59
- def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
- super().__init__(clusterify_all_meta(meta), **kwargs)
61
-
62
- def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
- if not self.return_info:
64
- wav = super().__getitem__(index)
65
- assert isinstance(wav, torch.Tensor)
66
- return wav
67
- wav, meta = super().__getitem__(index)
68
- return wav, AudioInfo(**meta.to_dict())
69
-
70
-
71
- def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72
- """Preprocess a single keyword or possible a list of keywords."""
73
- if isinstance(value, list):
74
- return get_keyword_list(value)
75
- else:
76
- return get_keyword(value)
77
-
78
-
79
- def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80
- """Preprocess a single keyword."""
81
- if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82
- return None
83
- else:
84
- return value.strip()
85
-
86
-
87
- def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88
- """Preprocess a single keyword."""
89
- if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90
- return None
91
- else:
92
- return value.strip().lower()
93
-
94
-
95
- def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96
- """Preprocess a list of keywords."""
97
- if isinstance(values, str):
98
- values = [v.strip() for v in re.split(r'[,\s]', values)]
99
- elif isinstance(values, float) and math.isnan(values):
100
- values = []
101
- if not isinstance(values, list):
102
- logger.debug(f"Unexpected keyword list {values}")
103
- values = [str(values)]
104
-
105
- kws = [get_keyword(v) for v in values]
106
- kw_list = [k for k in kws if k is not None]
107
- if len(kw_list) == 0:
108
- return None
109
- else:
110
- return kw_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/music_dataset.py DELETED
@@ -1,349 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Dataset of music tracks with rich metadata.
7
- """
8
- from dataclasses import dataclass, field, fields, replace
9
- import gzip
10
- import json
11
- import logging
12
- from pathlib import Path
13
- import random
14
- import typing as tp
15
- import pretty_midi
16
- import numpy as np
17
-
18
- import torch
19
- import torch.nn.functional as F
20
- from .btc_chords import Chords
21
-
22
- from .info_audio_dataset import (
23
- InfoAudioDataset,
24
- AudioInfo,
25
- get_keyword_list,
26
- get_keyword,
27
- get_string
28
- )
29
- from ..modules.conditioners import (
30
- ConditioningAttributes,
31
- JointEmbedCondition,
32
- WavCondition,
33
- ChordCondition,
34
- BeatCondition
35
- )
36
- from ..utils.utils import warn_once
37
-
38
-
39
- logger = logging.getLogger(__name__)
40
-
41
- CHORDS = Chords()
42
-
43
-
44
- @dataclass
45
- class MusicInfo(AudioInfo):
46
- """Segment info augmented with music metadata.
47
- """
48
- # music-specific metadata
49
- title: tp.Optional[str] = None
50
- artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
51
- key: tp.Optional[str] = None
52
- bpm: tp.Optional[float] = None
53
- genre: tp.Optional[str] = None
54
- moods: tp.Optional[list] = None
55
- keywords: tp.Optional[list] = None
56
- description: tp.Optional[str] = None
57
- name: tp.Optional[str] = None
58
- instrument: tp.Optional[str] = None
59
- chord: tp.Optional[ChordCondition] = None
60
- beat: tp.Optional[BeatCondition] = None
61
- # original wav accompanying the metadata
62
- self_wav: tp.Optional[WavCondition] = None
63
- # dict mapping attributes names to tuple of wav, text and metadata
64
- joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
65
-
66
- @property
67
- def has_music_meta(self) -> bool:
68
- return self.name is not None
69
-
70
- def to_condition_attributes(self) -> ConditioningAttributes:
71
- out = ConditioningAttributes()
72
- for _field in fields(self):
73
- key, value = _field.name, getattr(self, _field.name)
74
- if key == 'self_wav':
75
- out.wav[key] = value
76
- elif key == 'chord':
77
- out.chord[key] = value
78
- elif key == 'beat':
79
- out.beat[key] = value
80
- elif key == 'joint_embed':
81
- for embed_attribute, embed_cond in value.items():
82
- out.joint_embed[embed_attribute] = embed_cond
83
- else:
84
- if isinstance(value, list):
85
- value = ' '.join(value)
86
- out.text[key] = value
87
- return out
88
-
89
- @staticmethod
90
- def attribute_getter(attribute):
91
- if attribute == 'bpm':
92
- preprocess_func = get_bpm
93
- elif attribute == 'key':
94
- preprocess_func = get_musical_key
95
- elif attribute in ['moods', 'keywords']:
96
- preprocess_func = get_keyword_list
97
- elif attribute in ['genre', 'name', 'instrument']:
98
- preprocess_func = get_keyword
99
- elif attribute in ['title', 'artist', 'description']:
100
- preprocess_func = get_string
101
- else:
102
- preprocess_func = None
103
- return preprocess_func
104
-
105
- @classmethod
106
- def from_dict(cls, dictionary: dict, fields_required: bool = False):
107
- _dictionary: tp.Dict[str, tp.Any] = {}
108
-
109
- # allow a subset of attributes to not be loaded from the dictionary
110
- # these attributes may be populated later
111
- post_init_attributes = ['self_wav', 'chord', 'beat', 'joint_embed']
112
- optional_fields = ['keywords']
113
-
114
- for _field in fields(cls):
115
- if _field.name in post_init_attributes:
116
- continue
117
- elif _field.name not in dictionary:
118
- if fields_required and _field.name not in optional_fields:
119
- raise KeyError(f"Unexpected missing key: {_field.name}")
120
- else:
121
- preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
122
- value = dictionary[_field.name]
123
- if preprocess_func:
124
- value = preprocess_func(value)
125
- _dictionary[_field.name] = value
126
- return cls(**_dictionary)
127
-
128
-
129
- def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
130
- drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
131
- """Augment MusicInfo description with additional metadata fields and potential dropout.
132
- Additional textual attributes are added given probability 'merge_text_conditions_p' and
133
- the original textual description is dropped from the augmented description given probability drop_desc_p.
134
-
135
- Args:
136
- music_info (MusicInfo): The music metadata to augment.
137
- merge_text_p (float): Probability of merging additional metadata to the description.
138
- If provided value is 0, then no merging is performed.
139
- drop_desc_p (float): Probability of dropping the original description on text merge.
140
- if provided value is 0, then no drop out is performed.
141
- drop_other_p (float): Probability of dropping the other fields used for text augmentation.
142
- Returns:
143
- MusicInfo: The MusicInfo with augmented textual description.
144
- """
145
- def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
146
- valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
147
- valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
148
- keep_field = random.uniform(0, 1) < drop_other_p
149
- return valid_field_name and valid_field_value and keep_field
150
-
151
- def process_value(v: tp.Any) -> str:
152
- if isinstance(v, (int, float, str)):
153
- return str(v)
154
- if isinstance(v, list):
155
- return ", ".join(v)
156
- else:
157
- raise ValueError(f"Unknown type for text value! ({type(v), v})")
158
-
159
- description = music_info.description
160
-
161
- metadata_text = ""
162
- # metadata_text = "rock style music, consistent rhythm, catchy song."
163
- if random.uniform(0, 1) < merge_text_p:
164
- meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
165
- for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
166
- random.shuffle(meta_pairs)
167
- metadata_text = ". ".join(meta_pairs)
168
- description = description if not random.uniform(0, 1) < drop_desc_p else None
169
- logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
170
-
171
- if description is None:
172
- description = metadata_text if len(metadata_text) > 1 else None
173
- else:
174
- description = ". ".join([description.rstrip('.'), metadata_text])
175
- description = description.strip() if description else None
176
-
177
- music_info = replace(music_info)
178
- music_info.description = description
179
- return music_info
180
-
181
-
182
- class Paraphraser:
183
- def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
184
- self.paraphrase_p = paraphrase_p
185
- open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
186
- with open_fn(paraphrase_source, 'rb') as f: # type: ignore
187
- self.paraphrase_source = json.loads(f.read())
188
- logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
189
-
190
- def sample_paraphrase(self, audio_path: str, description: str):
191
- if random.random() >= self.paraphrase_p:
192
- return description
193
- info_path = Path(audio_path).with_suffix('.json')
194
- if info_path not in self.paraphrase_source:
195
- warn_once(logger, f"{info_path} not in paraphrase source!")
196
- return description
197
- new_desc = random.choice(self.paraphrase_source[info_path])
198
- logger.debug(f"{description} -> {new_desc}")
199
- return new_desc
200
-
201
-
202
- class MusicDataset(InfoAudioDataset):
203
- """Music dataset is an AudioDataset with music-related metadata.
204
-
205
- Args:
206
- info_fields_required (bool): Whether to enforce having required fields.
207
- merge_text_p (float): Probability of merging additional metadata to the description.
208
- drop_desc_p (float): Probability of dropping the original description on text merge.
209
- drop_other_p (float): Probability of dropping the other fields used for text augmentation.
210
- joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
211
- paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
212
- paraphrases for the description. The json should be a dict with keys are the
213
- original info path (e.g. track_path.json) and each value is a list of possible
214
- paraphrased.
215
- paraphrase_p (float): probability of taking a paraphrase.
216
-
217
- See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
218
- """
219
- def __init__(self, *args, info_fields_required: bool = True,
220
- merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
221
- joint_embed_attributes: tp.List[str] = [],
222
- paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
223
- **kwargs):
224
- kwargs['return_info'] = True # We require the info for each song of the dataset.
225
- super().__init__(*args, **kwargs)
226
- self.info_fields_required = info_fields_required
227
- self.merge_text_p = merge_text_p
228
- self.drop_desc_p = drop_desc_p
229
- self.drop_other_p = drop_other_p
230
- self.joint_embed_attributes = joint_embed_attributes
231
- self.paraphraser = None
232
- self.downsample_rate = 640
233
- self.sr = 32000
234
- if paraphrase_source is not None:
235
- self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
236
-
237
- def __getitem__(self, index):
238
- wav, info = super().__getitem__(index) # wav_seg and seg_info
239
- info_data = info.to_dict()
240
-
241
- # unpack info
242
- target_sr = self.sr
243
- n_frames_wave = info.n_frames
244
- n_frames_feat = int(info.n_frames // self.downsample_rate)
245
-
246
- music_info_path = str(info.meta.path).replace('no_vocal.wav', 'tags.json')
247
- chord_path = str(info.meta.path).replace('no_vocal.wav', 'chord.lab')
248
- beats_path = str(info.meta.path).replace('no_vocal.wav', 'beats.npy')
249
-
250
- if all([
251
- not Path(music_info_path).exists(),
252
- not Path(beats_path).exists(),
253
- not Path(chord_path).exists(),
254
- ]):
255
- raise FileNotFoundError
256
-
257
- ### music info
258
- with open(music_info_path, 'r') as json_file:
259
- music_data = json.load(json_file)
260
- music_data.update(info_data)
261
- music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
262
- if self.paraphraser is not None:
263
- music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
264
- if self.merge_text_p:
265
- music_info = augment_music_info_description(
266
- music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
267
-
268
-
269
- ### load features to tensors ###
270
- feat_hz = target_sr/self.downsample_rate
271
- ## beat&bar: 2 x T
272
- feat_beats = np.zeros((2, n_frames_feat))
273
-
274
- beats_np = np.load(beats_path)
275
- beat_time = beats_np[:, 0]
276
- bar_time = beats_np[np.where(beats_np[:, 1] == 1)[0], 0]
277
- beat_frame = [
278
- int((t-info.seek_time)*feat_hz) for t in beat_time
279
- if (t >= info.seek_time and t < info.seek_time + self.segment_duration)]
280
- bar_frame =[
281
- int((t-info.seek_time)*feat_hz) for t in bar_time
282
- if (t >= info.seek_time and t < info.seek_time + self.segment_duration)]
283
- feat_beats[0, beat_frame] = 1
284
- feat_beats[1, bar_frame] = 1
285
- kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05])
286
- feat_beats[0] = np.convolve(feat_beats[0] , kernel, 'same') # apply soft kernel
287
- beat_events = feat_beats[0] + feat_beats[1]
288
- beat_events = torch.tensor(beat_events).unsqueeze(0) # [T] -> [1, T]
289
-
290
- music_info.beat = BeatCondition(beat=beat_events[None], length=torch.tensor([n_frames_feat]),
291
- bpm=[music_data["bpm"]], path=[music_info_path], seek_frame=[info.seek_time*target_sr//self.downsample_rate])
292
-
293
- ## chord: 12 x T
294
- feat_chord = np.zeros((12, n_frames_feat)) # root| ivs
295
- with open(chord_path, 'r') as f:
296
- for line in f.readlines():
297
- splits = line.split()
298
- if len(splits) == 3:
299
- st_sec, ed_sec, ctag = splits
300
- st_sec = float(st_sec) - info.seek_time
301
- ed_sec = float(ed_sec) - info.seek_time
302
- st_frame = int(st_sec*feat_hz)
303
- ed_frame = int(ed_sec*feat_hz)
304
-
305
- # 12 chorma
306
- mhot = CHORDS.chord(ctag)
307
- final_vec = np.roll(mhot[2], mhot[0])
308
-
309
- final_vec = final_vec[..., None]
310
- feat_chord[:, st_frame:ed_frame] = final_vec
311
- feat_chord = torch.from_numpy(feat_chord)
312
-
313
- music_info.chord = ChordCondition(
314
- chord=feat_chord[None], length=torch.tensor([n_frames_feat]),
315
- bpm=[music_data["bpm"]], path=[chord_path], seek_frame=[info.seek_time*self.sr//self.downsample_rate])
316
-
317
- music_info.self_wav = WavCondition(
318
- wav=wav[None], length=torch.tensor([info.n_frames]),
319
- sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
320
-
321
- for att in self.joint_embed_attributes:
322
- att_value = getattr(music_info, att)
323
- joint_embed_cond = JointEmbedCondition(
324
- wav[None], [att_value], torch.tensor([info.n_frames]),
325
- sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
326
- music_info.joint_embed[att] = joint_embed_cond
327
-
328
- return wav, music_info
329
-
330
-
331
- def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
332
- """Preprocess key keywords, discarding them if there are multiple key defined."""
333
- if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
334
- return None
335
- elif ',' in value:
336
- # For now, we discard when multiple keys are defined separated with comas
337
- return None
338
- else:
339
- return value.strip().lower()
340
-
341
-
342
- def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
343
- """Preprocess to a float."""
344
- if value is None:
345
- return None
346
- try:
347
- return float(value)
348
- except ValueError:
349
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/sound_dataset.py DELETED
@@ -1,330 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Dataset of audio with a simple description.
7
- """
8
-
9
- from dataclasses import dataclass, fields, replace
10
- import json
11
- from pathlib import Path
12
- import random
13
- import typing as tp
14
-
15
- import numpy as np
16
- import torch
17
-
18
- from .info_audio_dataset import (
19
- InfoAudioDataset,
20
- get_keyword_or_keyword_list
21
- )
22
- from ..modules.conditioners import (
23
- ConditioningAttributes,
24
- SegmentWithAttributes,
25
- WavCondition,
26
- )
27
-
28
-
29
- EPS = torch.finfo(torch.float32).eps
30
- TARGET_LEVEL_LOWER = -35
31
- TARGET_LEVEL_UPPER = -15
32
-
33
-
34
- @dataclass
35
- class SoundInfo(SegmentWithAttributes):
36
- """Segment info augmented with Sound metadata.
37
- """
38
- description: tp.Optional[str] = None
39
- self_wav: tp.Optional[torch.Tensor] = None
40
-
41
- @property
42
- def has_sound_meta(self) -> bool:
43
- return self.description is not None
44
-
45
- def to_condition_attributes(self) -> ConditioningAttributes:
46
- out = ConditioningAttributes()
47
-
48
- for _field in fields(self):
49
- key, value = _field.name, getattr(self, _field.name)
50
- if key == 'self_wav':
51
- out.wav[key] = value
52
- else:
53
- out.text[key] = value
54
- return out
55
-
56
- @staticmethod
57
- def attribute_getter(attribute):
58
- if attribute == 'description':
59
- preprocess_func = get_keyword_or_keyword_list
60
- else:
61
- preprocess_func = None
62
- return preprocess_func
63
-
64
- @classmethod
65
- def from_dict(cls, dictionary: dict, fields_required: bool = False):
66
- _dictionary: tp.Dict[str, tp.Any] = {}
67
-
68
- # allow a subset of attributes to not be loaded from the dictionary
69
- # these attributes may be populated later
70
- post_init_attributes = ['self_wav']
71
-
72
- for _field in fields(cls):
73
- if _field.name in post_init_attributes:
74
- continue
75
- elif _field.name not in dictionary:
76
- if fields_required:
77
- raise KeyError(f"Unexpected missing key: {_field.name}")
78
- else:
79
- preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
80
- value = dictionary[_field.name]
81
- if preprocess_func:
82
- value = preprocess_func(value)
83
- _dictionary[_field.name] = value
84
- return cls(**_dictionary)
85
-
86
-
87
- class SoundDataset(InfoAudioDataset):
88
- """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
89
-
90
- Args:
91
- info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
92
- external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
93
- The metadata files contained in this folder are expected to match the stem of the audio file with
94
- a json extension.
95
- aug_p (float): Probability of performing audio mixing augmentation on the batch.
96
- mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
97
- mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
98
- mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
99
- mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
100
- kwargs: Additional arguments for AudioDataset.
101
-
102
- See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
103
- """
104
- def __init__(
105
- self,
106
- *args,
107
- info_fields_required: bool = True,
108
- external_metadata_source: tp.Optional[str] = None,
109
- aug_p: float = 0.,
110
- mix_p: float = 0.,
111
- mix_snr_low: int = -5,
112
- mix_snr_high: int = 5,
113
- mix_min_overlap: float = 0.5,
114
- **kwargs
115
- ):
116
- kwargs['return_info'] = True # We require the info for each song of the dataset.
117
- super().__init__(*args, **kwargs)
118
- self.info_fields_required = info_fields_required
119
- self.external_metadata_source = external_metadata_source
120
- self.aug_p = aug_p
121
- self.mix_p = mix_p
122
- if self.aug_p > 0:
123
- assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
124
- assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
125
- self.mix_snr_low = mix_snr_low
126
- self.mix_snr_high = mix_snr_high
127
- self.mix_min_overlap = mix_min_overlap
128
-
129
- def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
130
- """Get path of JSON with metadata (description, etc.).
131
- If there exists a JSON with the same name as 'path.name', then it will be used.
132
- Else, such JSON will be searched for in an external json source folder if it exists.
133
- """
134
- info_path = Path(path).with_suffix('.json')
135
- if Path(info_path).exists():
136
- return info_path
137
- elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
138
- return Path(self.external_metadata_source) / info_path.name
139
- else:
140
- raise Exception(f"Unable to find a metadata JSON for path: {path}")
141
-
142
- def __getitem__(self, index):
143
- wav, info = super().__getitem__(index)
144
- info_data = info.to_dict()
145
- info_path = self._get_info_path(info.meta.path)
146
- if Path(info_path).exists():
147
- with open(info_path, 'r') as json_file:
148
- sound_data = json.load(json_file)
149
- sound_data.update(info_data)
150
- sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
151
- # if there are multiple descriptions, sample one randomly
152
- if isinstance(sound_info.description, list):
153
- sound_info.description = random.choice(sound_info.description)
154
- else:
155
- sound_info = SoundInfo.from_dict(info_data, fields_required=False)
156
-
157
- sound_info.self_wav = WavCondition(
158
- wav=wav[None], length=torch.tensor([info.n_frames]),
159
- sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
160
-
161
- return wav, sound_info
162
-
163
- def collater(self, samples):
164
- # when training, audio mixing is performed in the collate function
165
- wav, sound_info = super().collater(samples) # SoundDataset always returns infos
166
- if self.aug_p > 0:
167
- wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
168
- snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
169
- min_overlap=self.mix_min_overlap)
170
- return wav, sound_info
171
-
172
-
173
- def rms_f(x: torch.Tensor) -> torch.Tensor:
174
- return (x ** 2).mean(1).pow(0.5)
175
-
176
-
177
- def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
178
- """Normalize the signal to the target level."""
179
- rms = rms_f(audio)
180
- scalar = 10 ** (target_level / 20) / (rms + EPS)
181
- audio = audio * scalar.unsqueeze(1)
182
- return audio
183
-
184
-
185
- def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
186
- return (abs(audio) > clipping_threshold).any(1)
187
-
188
-
189
- def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
190
- start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
191
- remainder = src.shape[1] - start
192
- if dst.shape[1] > remainder:
193
- src[:, start:] = src[:, start:] + dst[:, :remainder]
194
- else:
195
- src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
196
- return src
197
-
198
-
199
- def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
200
- target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
201
- """Function to mix clean speech and noise at various SNR levels.
202
-
203
- Args:
204
- clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
205
- noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
206
- snr (int): SNR level when mixing.
207
- min_overlap (float): Minimum overlap between the two mixed sources.
208
- target_level (int): Gain level in dB.
209
- clipping_threshold (float): Threshold for clipping the audio.
210
- Returns:
211
- torch.Tensor: The mixed audio, of shape [B, T].
212
- """
213
- if clean.shape[1] > noise.shape[1]:
214
- noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
215
- else:
216
- noise = noise[:, :clean.shape[1]]
217
-
218
- # normalizing to -25 dB FS
219
- clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
220
- clean = normalize(clean, target_level)
221
- rmsclean = rms_f(clean)
222
-
223
- noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
224
- noise = normalize(noise, target_level)
225
- rmsnoise = rms_f(noise)
226
-
227
- # set the noise level for a given SNR
228
- noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
229
- noisenewlevel = noise * noisescalar
230
-
231
- # mix noise and clean speech
232
- noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
233
-
234
- # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
235
- # there is a chance of clipping that might happen with very less probability, which is not a major issue.
236
- noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
237
- rmsnoisy = rms_f(noisyspeech)
238
- scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
239
- noisyspeech = noisyspeech * scalarnoisy
240
- clean = clean * scalarnoisy
241
- noisenewlevel = noisenewlevel * scalarnoisy
242
-
243
- # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
244
- clipped = is_clipped(noisyspeech)
245
- if clipped.any():
246
- noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
247
- noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
248
-
249
- return noisyspeech
250
-
251
-
252
- def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
253
- if snr_low == snr_high:
254
- snr = snr_low
255
- else:
256
- snr = np.random.randint(snr_low, snr_high)
257
- mix = snr_mixer(src, dst, snr, min_overlap)
258
- return mix
259
-
260
-
261
- def mix_text(src_text: str, dst_text: str):
262
- """Mix text from different sources by concatenating them."""
263
- if src_text == dst_text:
264
- return src_text
265
- return src_text + " " + dst_text
266
-
267
-
268
- def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
269
- snr_low: int, snr_high: int, min_overlap: float):
270
- """Mix samples within a batch, summing the waveforms and concatenating the text infos.
271
-
272
- Args:
273
- wavs (torch.Tensor): Audio tensors of shape [B, C, T].
274
- infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
275
- aug_p (float): Augmentation probability.
276
- mix_p (float): Proportion of items in the batch to mix (and merge) together.
277
- snr_low (int): Lowerbound for sampling SNR.
278
- snr_high (int): Upperbound for sampling SNR.
279
- min_overlap (float): Minimum overlap between mixed samples.
280
- Returns:
281
- tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
282
- and mixed SoundInfo for the given batch.
283
- """
284
- # no mixing to perform within the batch
285
- if mix_p == 0:
286
- return wavs, infos
287
-
288
- if random.uniform(0, 1) < aug_p:
289
- # perform all augmentations on waveforms as [B, T]
290
- # randomly picking pairs of audio to mix
291
- assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
292
- wavs = wavs.mean(dim=1, keepdim=False)
293
- B, T = wavs.shape
294
- k = int(mix_p * B)
295
- mixed_sources_idx = torch.randperm(B)[:k]
296
- mixed_targets_idx = torch.randperm(B)[:k]
297
- aug_wavs = snr_mix(
298
- wavs[mixed_sources_idx],
299
- wavs[mixed_targets_idx],
300
- snr_low,
301
- snr_high,
302
- min_overlap,
303
- )
304
- # mixing textual descriptions in metadata
305
- descriptions = [info.description for info in infos]
306
- aug_infos = []
307
- for i, j in zip(mixed_sources_idx, mixed_targets_idx):
308
- text = mix_text(descriptions[i], descriptions[j])
309
- m = replace(infos[i])
310
- m.description = text
311
- aug_infos.append(m)
312
-
313
- # back to [B, C, T]
314
- aug_wavs = aug_wavs.unsqueeze(1)
315
- assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
316
- assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
317
- assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
318
-
319
- return aug_wavs, aug_infos # [B, C, T]
320
- else:
321
- # randomly pick samples in the batch to match
322
- # the batch size when performing audio mixing
323
- B, C, T = wavs.shape
324
- k = int(mix_p * B)
325
- wav_idx = torch.randperm(B)[:k]
326
- wavs = wavs[wav_idx]
327
- infos = [infos[i] for i in wav_idx]
328
- assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
329
-
330
- return wavs, infos # [B, C, T]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/data/zip.py DELETED
@@ -1,76 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Utility for reading some info from inside a zip file.
7
- """
8
-
9
- import typing
10
- import zipfile
11
-
12
- from dataclasses import dataclass
13
- from functools import lru_cache
14
- from typing_extensions import Literal
15
-
16
-
17
- DEFAULT_SIZE = 32
18
- MODE = Literal['r', 'w', 'x', 'a']
19
-
20
-
21
- @dataclass(order=True)
22
- class PathInZip:
23
- """Hold a path of file within a zip file.
24
-
25
- Args:
26
- path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
- Let's assume there is a zip file /some/location/foo.zip
28
- and inside of it is a json file located at /data/file1.json,
29
- Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
- """
31
-
32
- INFO_PATH_SEP = ':'
33
- zip_path: str
34
- file_path: str
35
-
36
- def __init__(self, path: str) -> None:
37
- split_path = path.split(self.INFO_PATH_SEP)
38
- assert len(split_path) == 2
39
- self.zip_path, self.file_path = split_path
40
-
41
- @classmethod
42
- def from_paths(cls, zip_path: str, file_path: str):
43
- return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44
-
45
- def __str__(self) -> str:
46
- return self.zip_path + self.INFO_PATH_SEP + self.file_path
47
-
48
-
49
- def _open_zip(path: str, mode: MODE = 'r'):
50
- return zipfile.ZipFile(path, mode)
51
-
52
-
53
- _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54
-
55
-
56
- def set_zip_cache_size(max_size: int):
57
- """Sets the maximal LRU caching for zip file opening.
58
-
59
- Args:
60
- max_size (int): the maximal LRU cache.
61
- """
62
- global _cached_open_zip
63
- _cached_open_zip = lru_cache(max_size)(_open_zip)
64
-
65
-
66
- def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67
- """Opens a file stored inside a zip and returns a file-like object.
68
-
69
- Args:
70
- path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
- mode (str): The mode in which to open the file with.
72
- Returns:
73
- A file-like object for PathInZip.
74
- """
75
- zf = _cached_open_zip(path_in_zip.zip_path)
76
- return zf.open(path_in_zip.file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/environment.py DELETED
@@ -1,176 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
- """
10
-
11
- import logging
12
- import os
13
- from pathlib import Path
14
- import re
15
- import typing as tp
16
-
17
- import omegaconf
18
-
19
- from .utils.cluster import _guess_cluster_type
20
-
21
-
22
- logger = logging.getLogger(__name__)
23
-
24
-
25
- class AudioCraftEnvironment:
26
- """Environment configuration for teams and clusters.
27
-
28
- AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
- or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
30
- provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
- allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
32
- map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
33
-
34
- The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
- Use the following environment variables to specify the cluster, team or configuration:
36
-
37
- AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
- cannot be inferred automatically.
39
- AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
- If not set, configuration is read from config/teams.yaml.
41
- AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
- Cluster configuration are shared across teams to match compute allocation,
43
- specify your cluster configuration in the configuration file under a key mapping
44
- your team name.
45
- """
46
- _instance = None
47
- DEFAULT_TEAM = "default"
48
-
49
- def __init__(self) -> None:
50
- """Loads configuration."""
51
- self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
- cluster_type = _guess_cluster_type()
53
- cluster = os.getenv(
54
- "AUDIOCRAFT_CLUSTER", cluster_type.value
55
- )
56
- logger.info("Detecting cluster type %s", cluster_type)
57
-
58
- self.cluster: str = cluster
59
-
60
- config_path = os.getenv(
61
- "AUDIOCRAFT_CONFIG",
62
- Path(__file__)
63
- .parent.parent.joinpath("config/teams", self.team)
64
- .with_suffix(".yaml"),
65
- )
66
- self.config = omegaconf.OmegaConf.load(config_path)
67
- self._dataset_mappers = []
68
- cluster_config = self._get_cluster_config()
69
- if "dataset_mappers" in cluster_config:
70
- for pattern, repl in cluster_config["dataset_mappers"].items():
71
- regex = re.compile(pattern)
72
- self._dataset_mappers.append((regex, repl))
73
-
74
- def _get_cluster_config(self) -> omegaconf.DictConfig:
75
- assert isinstance(self.config, omegaconf.DictConfig)
76
- return self.config[self.cluster]
77
-
78
- @classmethod
79
- def instance(cls):
80
- if cls._instance is None:
81
- cls._instance = cls()
82
- return cls._instance
83
-
84
- @classmethod
85
- def reset(cls):
86
- """Clears the environment and forces a reload on next invocation."""
87
- cls._instance = None
88
-
89
- @classmethod
90
- def get_team(cls) -> str:
91
- """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
- If not defined, defaults to "labs".
93
- """
94
- return cls.instance().team
95
-
96
- @classmethod
97
- def get_cluster(cls) -> str:
98
- """Gets the detected cluster.
99
- This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
- """
101
- return cls.instance().cluster
102
-
103
- @classmethod
104
- def get_dora_dir(cls) -> Path:
105
- """Gets the path to the dora directory for the current team and cluster.
106
- Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
- """
108
- cluster_config = cls.instance()._get_cluster_config()
109
- dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
- logger.warning(f"Dora directory: {dora_dir}")
111
- return Path(dora_dir)
112
-
113
- @classmethod
114
- def get_reference_dir(cls) -> Path:
115
- """Gets the path to the reference directory for the current team and cluster.
116
- Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
- """
118
- cluster_config = cls.instance()._get_cluster_config()
119
- return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
-
121
- @classmethod
122
- def get_slurm_exclude(cls) -> tp.Optional[str]:
123
- """Get the list of nodes to exclude for that cluster."""
124
- cluster_config = cls.instance()._get_cluster_config()
125
- return cluster_config.get("slurm_exclude")
126
-
127
- @classmethod
128
- def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
- """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
-
131
- Args:
132
- partition_types (list[str], optional): partition types to retrieve. Values must be
133
- from ['global', 'team']. If not provided, the global partition is returned.
134
- """
135
- if not partition_types:
136
- partition_types = ["global"]
137
-
138
- cluster_config = cls.instance()._get_cluster_config()
139
- partitions = [
140
- cluster_config["partitions"][partition_type]
141
- for partition_type in partition_types
142
- ]
143
- return ",".join(partitions)
144
-
145
- @classmethod
146
- def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
- """Converts reference placeholder in path with configured reference dir to resolve paths.
148
-
149
- Args:
150
- path (str or Path): Path to resolve.
151
- Returns:
152
- Path: Resolved path.
153
- """
154
- path = str(path)
155
-
156
- if path.startswith("//reference"):
157
- reference_dir = cls.get_reference_dir()
158
- logger.warn(f"Reference directory: {reference_dir}")
159
- assert (
160
- reference_dir.exists() and reference_dir.is_dir()
161
- ), f"Reference directory does not exist: {reference_dir}."
162
- path = re.sub("^//reference", str(reference_dir), path)
163
-
164
- return Path(path)
165
-
166
- @classmethod
167
- def apply_dataset_mappers(cls, path: str) -> str:
168
- """Applies dataset mapping regex rules as defined in the configuration.
169
- If no rules are defined, the path is returned as-is.
170
- """
171
- instance = cls.instance()
172
-
173
- for pattern, repl in instance._dataset_mappers:
174
- path = pattern.sub(repl, path)
175
-
176
- return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Dora Grids."""
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/_base_explorers.py DELETED
@@ -1,80 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from abc import ABC, abstractmethod
8
- import time
9
- import typing as tp
10
- from dora import Explorer
11
- import treetable as tt
12
-
13
-
14
- def get_sheep_ping(sheep) -> tp.Optional[str]:
15
- """Return the amount of time since the Sheep made some update
16
- to its log. Returns a str using the relevant time unit."""
17
- ping = None
18
- if sheep.log is not None and sheep.log.exists():
19
- delta = time.time() - sheep.log.stat().st_mtime
20
- if delta > 3600 * 24:
21
- ping = f'{delta / (3600 * 24):.1f}d'
22
- elif delta > 3600:
23
- ping = f'{delta / (3600):.1f}h'
24
- elif delta > 60:
25
- ping = f'{delta / 60:.1f}m'
26
- else:
27
- ping = f'{delta:.1f}s'
28
- return ping
29
-
30
-
31
- class BaseExplorer(ABC, Explorer):
32
- """Base explorer for AudioCraft grids.
33
-
34
- All task specific solvers are expected to implement the `get_grid_metrics`
35
- method to specify logic about metrics to display for a given task.
36
-
37
- If additional stages are used, the child explorer must define how to handle
38
- these new stages in the `process_history` and `process_sheep` methods.
39
- """
40
- def stages(self):
41
- return ["train", "valid", "evaluate"]
42
-
43
- def get_grid_meta(self):
44
- """Returns the list of Meta information to display for each XP/job.
45
- """
46
- return [
47
- tt.leaf("index", align=">"),
48
- tt.leaf("name", wrap=140),
49
- tt.leaf("state"),
50
- tt.leaf("sig", align=">"),
51
- tt.leaf("sid", align="<"),
52
- ]
53
-
54
- @abstractmethod
55
- def get_grid_metrics(self):
56
- """Return the metrics that should be displayed in the tracking table.
57
- """
58
- ...
59
-
60
- def process_sheep(self, sheep, history):
61
- train = {
62
- "epoch": len(history),
63
- }
64
- parts = {"train": train}
65
- for metrics in history:
66
- for key, sub in metrics.items():
67
- part = parts.get(key, {})
68
- if 'duration' in sub:
69
- # Convert to minutes for readability.
70
- sub['duration'] = sub['duration'] / 60.
71
- part.update(sub)
72
- parts[key] = part
73
- ping = get_sheep_ping(sheep)
74
- if ping is not None:
75
- for name in self.stages():
76
- if name not in parts:
77
- parts[name] = {}
78
- # Add the ping to each part for convenience.
79
- parts[name]['ping'] = ping
80
- return parts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/audiogen/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """AudioGen grids."""
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/audiogen/audiogen_base_16khz.py DELETED
@@ -1,23 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from ..musicgen._explorers import LMExplorer
8
- from ...environment import AudioCraftEnvironment
9
-
10
-
11
- @LMExplorer
12
- def explorer(launcher):
13
- partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
- launcher.slurm_(gpus=64, partition=partitions)
15
- launcher.bind_(solver='audiogen/audiogen_base_16khz')
16
- # replace this by the desired environmental sound dataset
17
- launcher.bind_(dset='internal/sounds_16khz')
18
-
19
- fsdp = {'autocast': False, 'fsdp.use': True}
20
- medium = {'model/lm/model_scale': 'medium'}
21
-
22
- launcher.bind_(fsdp)
23
- launcher(medium)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py DELETED
@@ -1,68 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Evaluation with objective metrics for the pretrained AudioGen models.
9
- This grid takes signature from the training grid and runs evaluation-only stage.
10
-
11
- When running the grid for the first time, please use:
12
- REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
13
- and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14
-
15
- Note that you need the proper metrics external libraries setup to use all
16
- the objective metrics activated in this grid. Refer to the README for more information.
17
- """
18
-
19
- import os
20
-
21
- from ..musicgen._explorers import GenerationEvalExplorer
22
- from ...environment import AudioCraftEnvironment
23
- from ... import train
24
-
25
-
26
- def eval(launcher, batch_size: int = 32):
27
- opts = {
28
- 'dset': 'audio/audiocaps_16khz',
29
- 'solver/audiogen/evaluation': 'objective_eval',
30
- 'execute_only': 'evaluate',
31
- '+dataset.evaluate.batch_size': batch_size,
32
- '+metrics.fad.tf.batch_size': 32,
33
- }
34
- # binary for FAD computation: replace this path with your own path
35
- metrics_opts = {
36
- 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
37
- }
38
- opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
39
- opt2 = {'transformer_lm.two_step_cfg': True}
40
-
41
- sub = launcher.bind(opts)
42
- sub.bind_(metrics_opts)
43
-
44
- # base objective metrics
45
- sub(opt1, opt2)
46
-
47
-
48
- @GenerationEvalExplorer
49
- def explorer(launcher):
50
- partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
51
- launcher.slurm_(gpus=4, partition=partitions)
52
-
53
- if 'REGEN' not in os.environ:
54
- folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
55
- with launcher.job_array():
56
- for sig in folder.iterdir():
57
- if not sig.is_symlink():
58
- continue
59
- xp = train.main.get_xp_from_sig(sig.name)
60
- launcher(xp.argv)
61
- return
62
-
63
- audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
64
- audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
65
-
66
- audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
67
- audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
68
- eval(audiogen_base_medium, batch_size=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/compression/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """EnCodec grids."""
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/compression/_explorers.py DELETED
@@ -1,55 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import treetable as tt
8
-
9
- from .._base_explorers import BaseExplorer
10
-
11
-
12
- class CompressionExplorer(BaseExplorer):
13
- eval_metrics = ["sisnr", "visqol"]
14
-
15
- def stages(self):
16
- return ["train", "valid", "evaluate"]
17
-
18
- def get_grid_meta(self):
19
- """Returns the list of Meta information to display for each XP/job.
20
- """
21
- return [
22
- tt.leaf("index", align=">"),
23
- tt.leaf("name", wrap=140),
24
- tt.leaf("state"),
25
- tt.leaf("sig", align=">"),
26
- ]
27
-
28
- def get_grid_metrics(self):
29
- """Return the metrics that should be displayed in the tracking table.
30
- """
31
- return [
32
- tt.group(
33
- "train",
34
- [
35
- tt.leaf("epoch"),
36
- tt.leaf("bandwidth", ".2f"),
37
- tt.leaf("adv", ".4f"),
38
- tt.leaf("d_loss", ".4f"),
39
- ],
40
- align=">",
41
- ),
42
- tt.group(
43
- "valid",
44
- [
45
- tt.leaf("bandwidth", ".2f"),
46
- tt.leaf("adv", ".4f"),
47
- tt.leaf("msspec", ".4f"),
48
- tt.leaf("sisnr", ".2f"),
49
- ],
50
- align=">",
51
- ),
52
- tt.group(
53
- "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
54
- ),
55
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/compression/debug.py DELETED
@@ -1,31 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Grid search file, simply list all the exp you want in `explorer`.
9
- Any new exp added there will be scheduled.
10
- You can cancel and experiment by commenting its line.
11
-
12
- This grid is a minimal example for debugging compression task
13
- and how to override parameters directly in a grid.
14
- Learn more about dora grids: https://github.com/facebookresearch/dora
15
- """
16
-
17
- from ._explorers import CompressionExplorer
18
- from ...environment import AudioCraftEnvironment
19
-
20
-
21
- @CompressionExplorer
22
- def explorer(launcher):
23
- partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
24
- launcher.slurm_(gpus=2, partition=partitions)
25
- launcher.bind_(solver='compression/debug')
26
-
27
- with launcher.job_array():
28
- # base debug task using config from solver=compression/debug
29
- launcher()
30
- # we can override parameters in the grid to launch additional xps
31
- launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/compression/encodec_audiogen_16khz.py DELETED
@@ -1,29 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Grid search file, simply list all the exp you want in `explorer`.
9
- Any new exp added there will be scheduled.
10
- You can cancel and experiment by commenting its line.
11
-
12
- This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
13
- """
14
-
15
- from ._explorers import CompressionExplorer
16
- from ...environment import AudioCraftEnvironment
17
-
18
-
19
- @CompressionExplorer
20
- def explorer(launcher):
21
- partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
- launcher.slurm_(gpus=8, partition=partitions)
23
- # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
24
- # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
25
- launcher.bind_(solver='compression/encodec_audiogen_16khz')
26
- # replace this by the desired sound dataset
27
- launcher.bind_(dset='internal/sounds_16khz')
28
- # launch xp
29
- launcher()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/compression/encodec_base_24khz.py DELETED
@@ -1,28 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Grid search file, simply list all the exp you want in `explorer`.
9
- Any new exp added there will be scheduled.
10
- You can cancel and experiment by commenting its line.
11
-
12
- This grid shows how to train a base causal EnCodec model at 24 kHz.
13
- """
14
-
15
- from ._explorers import CompressionExplorer
16
- from ...environment import AudioCraftEnvironment
17
-
18
-
19
- @CompressionExplorer
20
- def explorer(launcher):
21
- partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
- launcher.slurm_(gpus=8, partition=partitions)
23
- # base causal EnCodec trained on monophonic audio sampled at 24 kHz
24
- launcher.bind_(solver='compression/encodec_base_24khz')
25
- # replace this by the desired dataset
26
- launcher.bind_(dset='audio/example')
27
- # launch xp
28
- launcher()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/audiocraft/grids/compression/encodec_musicgen_32khz.py DELETED
@@ -1,34 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Grid search file, simply list all the exp you want in `explorer`.
9
- Any new exp added there will be scheduled.
10
- You can cancel and experiment by commenting its line.
11
-
12
- This grid shows how to train a MusicGen EnCodec model at 32 kHz.
13
- """
14
-
15
- from ._explorers import CompressionExplorer
16
- from ...environment import AudioCraftEnvironment
17
-
18
-
19
- @CompressionExplorer
20
- def explorer(launcher):
21
- partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
- launcher.slurm_(gpus=8, partition=partitions)
23
- # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
24
- # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
25
- launcher.bind_(solver='compression/encodec_musicgen_32khz')
26
- # replace this by the desired music dataset
27
- launcher.bind_(dset='internal/music_400k_32khz')
28
- # launch xp
29
- launcher()
30
- launcher({
31
- 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
32
- 'label': 'visqol',
33
- 'evaluate.metrics.visqol': True
34
- })