###
# Author: Kai Li
# Date: 2021-06-18 16:32:50
# LastEditors: Kai Li
# LastEditTime: 2021-06-19 01:02:04
###
import os
import warnings
import torch
import numpy as np
import soundfile as sf
def get_device(tensor_or_module, default=None):
    """Return the device of a tensor or module, falling back to `default`."""
    if hasattr(tensor_or_module, "device"):
        return tensor_or_module.device
    elif hasattr(tensor_or_module, "parameters"):
        return next(tensor_or_module.parameters()).device
    elif default is None:
        raise TypeError(
            f"Don't know how to get device of {type(tensor_or_module)} object"
        )
    else:
        return torch.device(default)

class Separator:
    """Minimal interface expected by `separate`: a model exposing `forward_wav` and `sample_rate`."""

    def forward_wav(self, wav, **kwargs):
        raise NotImplementedError

    def sample_rate(self):
        raise NotImplementedError

def separate(model, wav, **kwargs):
    """Dispatch separation to the numpy or torch backend depending on the input type."""
    if isinstance(wav, np.ndarray):
        return numpy_separate(model, wav, **kwargs)
    elif isinstance(wav, torch.Tensor):
        return torch_separate(model, wav, **kwargs)
    else:
        raise ValueError(
            f"Only support numpy arrays and torch tensors, received {type(wav)}"
        )

@torch.no_grad()
def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
    """Core logic of `separate`."""
    if model.in_channels is not None and wav.shape[-2] != model.in_channels:
        raise RuntimeError(
            f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels. "
            f"Please match the number of channels."
        )
    # Handle device placement
    input_device = get_device(wav, default="cpu")
    model_device = get_device(model, default="cpu")
    wav = wav.to(model_device)
    # Forward
    separate_func = getattr(model, "forward_wav", model)
    out_wavs = separate_func(wav, **kwargs)
    # FIXME: for now this is the best we can do.
    out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())
    # Back to input device (and numpy if necessary)
    out_wavs = out_wavs.to(input_device)
    return out_wavs

def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
    """Numpy interface to `separate`."""
    wav = torch.from_numpy(wav)
    out_wavs = torch_separate(model, wav, **kwargs)
    out_wavs = out_wavs.data.numpy()
    return out_wavs

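# Usage sketch (not part of the original file): `separate` accepts either a numpy array or
# a torch tensor shaped [n_channels, n_samples] and returns estimated sources in the same
# container type. `MySeparator` is a hypothetical model; any module with a `forward_wav`
# method and an `in_channels` attribute (as checked in `torch_separate`) would work:
#
#   model = MySeparator()                                # hypothetical Separator subclass
#   mix_np = np.random.randn(1, 16000).astype("float32")
#   est_np = separate(model, mix_np)                     # numpy in -> numpy out
#   est_pt = separate(model, torch.from_numpy(mix_np))   # tensor in -> tensor out
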
def wav_chunk_inference(model, mixture_tensor, sr=16000, target_length=12.0, hop_length=4.0, batch_size=10, n_tracks=3):
    """Run separation on overlapping chunks of a long mixture and overlap-add the results.

    Input:
        mixture_tensor: Tensor, [nch, input_length]
    Output:
        all_target_tensor: Tensor, [n_tracks, nch, input_length]
    """
    # Add a batch dimension; the chunking code below indexes [batch, nch, time].
    batch_mixture = mixture_tensor.unsqueeze(0)
    # split data into segments
    batch_length = batch_mixture.shape[-1]
    session = int(sr * target_length)  # samples per analysis window
    target = int(sr * target_length)   # samples kept from each window
    ignore = (session - target) // 2   # context dropped on each side (0 here, since session == target)
    hop = int(sr * hop_length)
    tr_ratio = target_length / hop_length  # number of overlapping windows covering each sample
    if ignore > 0:
        zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], ignore).type(batch_mixture.type()).to(batch_mixture.device)
        batch_mixture_pad = torch.cat([zero_pad, batch_mixture, zero_pad], -1)
    else:
        batch_mixture_pad = batch_mixture
    if target - hop > 0:
        hop_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], target - hop).type(batch_mixture.type()).to(batch_mixture.device)
        batch_mixture_pad = torch.cat([hop_pad, batch_mixture_pad, hop_pad], -1)
    skip_idx = ignore + target - hop
    zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], session).type(batch_mixture.type()).to(batch_mixture.device)
    num_session = (batch_mixture_pad.shape[-1] - session) // hop + 2
    all_target = torch.zeros(batch_mixture_pad.shape[0], n_tracks, batch_mixture_pad.shape[1], batch_mixture_pad.shape[2]).to(batch_mixture_pad.device)
    all_input = []
    all_segment_length = []
    for i in range(num_session):
        this_input = batch_mixture_pad[:, :, i * hop:i * hop + session]
        segment_length = this_input.shape[-1]
        if segment_length < session:
            # Zero-pad the last, shorter segment up to a full window
            this_input = torch.cat([this_input, zero_pad[:, :, :session - segment_length]], -1)
        all_input.append(this_input)
        all_segment_length.append(segment_length)
    all_input = torch.cat(all_input, 0)
    num_batch = num_session // batch_size
    if num_session % batch_size > 0:
        num_batch += 1
    for i in range(num_batch):
        this_input = all_input[i * batch_size:(i + 1) * batch_size]
        actual_batch_size = this_input.shape[0]
        with torch.no_grad():
            est_target = model(this_input)
        for j in range(actual_batch_size):
            # Drop the zero-padded tail and the ignored context, then overlap-add into the output buffer
            this_est_target = est_target[j, :, :, :all_segment_length[i * batch_size + j]][:, :, ignore:ignore + target].unsqueeze(0)
            all_target[:, :, :, ignore + (i * batch_size + j) * hop:ignore + (i * batch_size + j) * hop + target] += this_est_target
    all_target = all_target[:, :, :, skip_idx:skip_idx + batch_length].contiguous() / tr_ratio
    return all_target.squeeze(0)
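

# A minimal usage sketch (not part of the original file). `DummyModel` below is a
# stand-in for whatever separation network this utility is normally paired with; it
# only illustrates the [batch, n_tracks, nch, time] output layout that
# `wav_chunk_inference` expects from `model(...)`.
if __name__ == "__main__":
    class DummyModel(torch.nn.Module):
        def forward(self, x):  # x: [batch, nch, time]
            # "Separate" by repeating the mixture once per track
            return x.unsqueeze(1).repeat(1, 3, 1, 1)

    dummy = DummyModel()
    mixture = torch.randn(1, 16000 * 20)  # [nch, input_length]: 20 s of mono audio at 16 kHz
    est = wav_chunk_inference(dummy, mixture, sr=16000, n_tracks=3)
    print(est.shape)  # expected: [n_tracks, nch, input_length]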