# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)
import functools
import inspect
import logging
import math
import os
import time
from contextlib import contextmanager

import torch

logger = logging.getLogger(__name__)


def capture_init(init):
    """
    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)
    return __init__
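

# A minimal usage sketch (illustrative only; `MyModel` is hypothetical and
# not part of this file). Decorating `__init__` lets `serialize_model` below
# recover the exact constructor arguments later:
#
#     class MyModel(torch.nn.Module):
#         @capture_init
#         def __init__(self, hidden=64):
#             super().__init__()
#             self.linear = torch.nn.Linear(hidden, hidden)
#
#     model = MyModel(hidden=128)
#     model._init_args_kwargs   # ((), {'hidden': 128})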


def deserialize_model(package, strict=False):
    """Instantiate the serialized class and load its saved state dict."""
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        sig = inspect.signature(klass)
        kw = package['kwargs']
        for key in list(kw):
            if key not in sig.parameters:
                logger.warning("Dropping non-existent parameter %s", key)
                del kw[key]
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model


def copy_state(state):
    """Deep-copy a state dict to CPU so it is safe to keep around."""
    return {k: v.cpu().clone() for k, v in state.items()}


def serialize_model(model):
    """Package a model whose `__init__` was decorated with `capture_init`."""
    args, kwargs = model._init_args_kwargs
    state = copy_state(model.state_dict())
    return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}


@contextmanager
def swap_state(model, state):
    """Temporarily load `state` into `model`; restore the old state on exit."""
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(old_state)
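

# Usage sketch: temporarily evaluate with another set of weights, e.g. an EMA
# copy (`ema_state` and `validate` are hypothetical):
#
#     with swap_state(model, ema_state):
#         validate(model)
#     # the original weights are back here, even if validate() raised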


@contextmanager
def swap_cwd(cwd):
    """Temporarily change the working directory; restore it on exit."""
    old_cwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(old_cwd)


def pull_metric(history, name):
    """Collect the values of metric `name` from every entry of `history`
    that contains it."""
    out = []
    for metrics in history:
        if name in metrics:
            out.append(metrics[name])
    return out


class LogProgress:
    """
    Sort of like tqdm, but using log lines rather than a real-time progress bar.
    """
    def __init__(self, logger, iterable, updates=5, total=None,
                 name="LogProgress", level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        self._index += 1
        try:
            value = next(self._iterator)
        except StopIteration:
            raise
        else:
            return value
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by one iteration, so the metrics passed to
            # `update()` for the previous batch are included
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        self._speed = (1 + self._index) / (time.time() - self._begin)
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
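

# Usage sketch: `updates=5` yields roughly five log lines over the whole
# iteration, each showing speed plus whatever was last passed to `update()`
# (`dataloader` and `train_step` are hypothetical):
#
#     logprog = LogProgress(logger, dataloader, updates=5, name="train")
#     for batch in logprog:
#         loss = train_step(batch)
#         logprog.update(loss=f"{loss:.4f}")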


def colorize(text, color):
    """Wrap `text` in the given ANSI color code."""
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def bold(text):
    return colorize(text, "1")


def calculate_grad_norm(model):
    """L2 norm of all parameter gradients, concatenated into a single vector."""
    grads = [p.grad.data.flatten() for p in model.parameters()
             if p.grad is not None]
    return torch.cat(grads).norm(2)


def calculate_weight_norm(model):
    """L2 norm of all parameters, concatenated into a single vector."""
    weights = [p.data.flatten() for p in model.parameters()]
    return torch.cat(weights).norm(2)


def remove_pad(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list of B numpy arrays, each of shape [C, T] (or [T]),
            where T varies per item
    """
    results = []
    dim = inputs.dim()
    if dim == 3:
        C = inputs.size(1)
    for inp, length in zip(inputs, inputs_lengths):
        if dim == 3:  # [B, C, T]
            results.append(inp[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(inp[:length].view(-1).cpu().numpy())
    return results
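

# Usage sketch: strip right padding from a batch of variable-length signals,
# e.g. a [2, 1, 5] batch where the second item has true length 3:
#
#     inputs = torch.randn(2, 1, 5)
#     lengths = torch.tensor([5, 3])
#     outs = remove_pad(inputs, lengths)
#     # outs[0].shape == (1, 5), outs[1].shape == (1, 3)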


def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be
            unknown, and rank must be at least 2.
        frame_step: An integer denoting overlap offsets. Must be less than or
            equal to frame_length.
    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions.
    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]
    # split frames into subframes of length gcd(frame_length, frame_step),
    # so both the frame and the hop are whole numbers of subframes
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
    # for each frame, the indices of the output subframes it lands on
    frame = torch.arange(0, output_subframes).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.clone().detach().long().to(signal.device)
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result
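

# Worked example: two frames of length 4 with a hop of 2 overlap on two
# samples, so the overlapped region sums:
#
#     frames = torch.ones(2, 4)          # [frames, frame_length]
#     overlap_and_add(frames, 2)
#     # tensor([1., 1., 2., 2., 1., 1.]) # output_size = (2 - 1) * 2 + 4 = 6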