import tqdm import torch from einops import rearrange def scalar_to_batch_tensor(x, batch_size): return torch.tensor(x).repeat(batch_size) def parallelize( fn, *iterables, parallel: str = "thread_map", **kwargs ): if parallel == "thread_map": from tqdm.contrib.concurrent import thread_map return thread_map( fn, *iterables, **kwargs ) elif parallel == "process_map": from tqdm.contrib.concurrent import process_map return process_map( fn, *iterables, **kwargs ) elif parallel == "single": return [fn(x) for x in tqdm.tqdm(*iterables)] else: raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}") def codebook_flatten(tokens: torch.Tensor): """ flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time) """ return rearrange(tokens, "b c t -> b (t c)") def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None): """ unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time) """ tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c) return tokens