mikeross's picture
Duplicate from nateraw/deepafx-st
c983126
import io
import torch
import PIL.Image
import numpy as np
import scipy.signal
import librosa.display
import matplotlib.pyplot as plt
from torch.functional import Tensor
from torchvision.transforms import ToTensor
def compute_comparison_spectrogram(
x: np.ndarray,
y: np.ndarray,
sample_rate: float = 44100,
n_fft: int = 2048,
hop_length: int = 1024,
) -> Tensor:
X = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
X_db = librosa.amplitude_to_db(np.abs(X), ref=np.max)
Y = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
Y_db = librosa.amplitude_to_db(np.abs(Y), ref=np.max)
fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
img = librosa.display.specshow(
X_db,
ax=axs[0],
hop_length=hop_length,
x_axis="time",
y_axis="log",
sr=sample_rate,
)
# fig.colorbar(img, ax=axs[0])
img = librosa.display.specshow(
Y_db,
ax=axs[1],
hop_length=hop_length,
x_axis="time",
y_axis="log",
sr=sample_rate,
)
# fig.colorbar(img, ax=axs[1])
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="jpeg")
buf.seek(0)
image = PIL.Image.open(buf)
image = ToTensor()(image)
plt.close("all")
return image
def plot_multi_spectrum(
ys=None,
Hs=None,
legend=[],
title="Spectrum",
filename=None,
sample_rate=44100,
n_fft=1024,
zero_mean=False,
):
if Hs is None:
Hs = []
for y in ys:
X = get_average_spectrum(y, n_fft)
X_sm = smooth_spectrum(X)
Hs.append(X_sm)
bin_width = (sample_rate / 2) / (n_fft // 2)
freqs = np.arange(0, (sample_rate / 2) + bin_width, step=bin_width)
fig, ax1 = plt.subplots()
for idx, H in enumerate(Hs):
H = np.nan_to_num(H)
H = np.clip(H, 0, np.max(H))
H_dB = 20 * np.log10(H + 1e-8)
if zero_mean:
H_dB -= np.mean(H_dB)
if "Target" in legend[idx]:
ax1.plot(freqs, H_dB, linestyle="--", color="k")
else:
ax1.plot(freqs, H_dB)
plt.legend(legend)
ax1.set_xscale("log")
ax1.set_ylim([-80, 0])
ax1.set_xlim([100, 11000])
plt.title(title)
plt.ylabel("Magnitude (dB)")
plt.xlabel("Frequency (Hz)")
plt.grid(c="lightgray", which="both")
if filename is not None:
plt.savefig(f"{filename}.png", dpi=300)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="jpeg")
buf.seek(0)
image = PIL.Image.open(buf)
image = ToTensor()(image)
plt.close("all")
return image
def smooth_spectrum(H):
# apply Savgol filter for smoothed target curve
return scipy.signal.savgol_filter(H, 1025, 2)
def get_average_spectrum(x, n_fft):
X = torch.stft(x, n_fft, return_complex=True, normalized=True)
X = X.abs() # convert to magnitude
X = X.mean(dim=-1).view(-1) # average across frames
return X