jeduardogruiz commited on
Commit
6d8a42e
1 Parent(s): acfbd62

Create encoded.py

Browse files
Files changed (1) hide show
  1. encoded.py +105 -0
encoded.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [email protected]:facebookresearch/encodec.git
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ """Various utilities."""
10
+
11
+ from hashlib import sha256
12
+ from pathlib import Path
13
+ import typing as tp
14
+
15
+ import torch
16
+ import torchaudio
17
+
18
+
19
+ def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
20
+ # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
21
+ # e.g., more than 2 frames per position.
22
+ # The core idea is to use a weight function that is a triangle,
23
+ # with a maximum value at the middle of the segment.
24
+ # We use this weighting when summing the frames, and divide by the sum of weights
25
+ # for each positions at the end. Thus:
26
+ # - if a frame is the only one to cover a position, the weighting is a no-op.
27
+ # - if 2 frames cover a position:
28
+ # ... ...
29
+ # / \/ \
30
+ # / /\ \
31
+ # S T , i.e. S offset of second frame starts, T end of first frame.
32
+ # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
33
+ # After the final normalization, the weight of the second frame at position `t` is
34
+ # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
35
+ #
36
+ # - if more than 2 frames overlap at a given point, we hope that by induction
37
+ # something sensible happens.
38
+ assert len(frames)
39
+ device = frames[0].device
40
+ dtype = frames[0].dtype
41
+ shape = frames[0].shape[:-1]
42
+ total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
43
+
44
+ frame_length = frames[0].shape[-1]
45
+ t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1]
46
+ weight = 0.5 - (t - 0.5).abs()
47
+
48
+ sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
49
+ out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
50
+ offset: int = 0
51
+
52
+ for frame in frames:
53
+ frame_length = frame.shape[-1]
54
+ out[..., offset:offset + frame_length] += weight[:frame_length] * frame
55
+ sum_weight[offset:offset + frame_length] += weight[:frame_length]
56
+ offset += stride
57
+ assert sum_weight.min() > 0
58
+ return out / sum_weight
59
+
60
+
61
+ def _get_checkpoint_url(root_url: str, checkpoint: str):
62
+ if not root_url.endswith('/'):
63
+ root_url += '/'
64
+ return root_url + checkpoint
65
+
66
+
67
+ def _check_checksum(path: Path, checksum: str):
68
+ sha = sha256()
69
+ with open(path, 'rb') as file:
70
+ while True:
71
+ buf = file.read(2**20)
72
+ if not buf:
73
+ break
74
+ sha.update(buf)
75
+ actual_checksum = sha.hexdigest()[:len(checksum)]
76
+ if actual_checksum != checksum:
77
+ raise RuntimeError(f'Invalid checksum for file {path}, '
78
+ f'expected {checksum} but got {actual_checksum}')
79
+
80
+
81
+ def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
82
+ assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
83
+ assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
84
+ *shape, channels, length = wav.shape
85
+ if target_channels == 1:
86
+ wav = wav.mean(-2, keepdim=True)
87
+ elif target_channels == 2:
88
+ wav = wav.expand(*shape, target_channels, length)
89
+ elif channels == 1:
90
+ wav = wav.expand(target_channels, -1)
91
+ else:
92
+ raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
93
+ wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
94
+ return wav
95
+
96
+
97
+ def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
98
+ sample_rate: int, rescale: bool = False):
99
+ limit = 0.99
100
+ mx = wav.abs().max()
101
+ if rescale:
102
+ wav = wav * min(limit / mx, 1)
103
+ else:
104
+ wav = wav.clamp(-limit, limit)
105
+ torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)