# Estella / inference.py
# Uploaded by Ocillus in commit 76da613 ("Upload 7 files", verified).
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from scipy import signal
from PIL import Image
def create_spectrogram(wav_file, output_folder):
    """Render *wav_file* as a borderless spectrogram PNG inside *output_folder*.

    The PNG takes the WAV file's base name with a ``.png`` extension. Two-dimensional
    (multi-channel) audio is averaged into a single channel first. Power is drawn in
    dB (``10 * log10``) with the ``inferno`` colormap, the y-axis spans 0 to the
    Nyquist frequency, and axes are hidden. After saving, every pure-white pixel of
    the image is recoloured black via ``convert_white_to_black``.

    Args:
        wav_file: Path to the input ``.wav`` file.
        output_folder: Existing directory that receives the PNG.
    """
    rate, samples = wavfile.read(wav_file)
    if samples.ndim == 2:
        # Average across axis 1 (one column per channel) to get mono audio.
        samples = np.mean(samples, axis=1)

    freqs, seg_times, power = signal.spectrogram(samples, rate)
    nyquist = rate / 2
    end_time = seg_times[-1]

    fig, ax = plt.subplots(figsize=(10, 4))
    # NOTE(review): zero-power bins (e.g. digital silence) make log10 return -inf
    # and emit a RuntimeWarning; confirm how pcolormesh renders those cells.
    ax.pcolormesh(seg_times, freqs, 10 * np.log10(power), shading='gouraud', cmap='inferno')
    ax.set_ylim(0, nyquist)
    ax.axis('off')
    ax.set_xlim(0, end_time)
    # Black filler to the right of the last time point.
    # NOTE(review): the x-limit ends at end_time, which is exactly where this patch
    # begins, so it may lie outside the visible area — confirm it is still needed.
    ax.add_patch(plt.Rectangle((end_time, 0), 10, nyquist, facecolor='black', edgecolor='none'))

    png_name = os.path.splitext(os.path.basename(wav_file))[0] + '.png'
    png_path = os.path.join(output_folder, png_name)
    fig.savefig(png_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    # Post-process the saved image so white (unpainted) pixels become black.
    convert_white_to_black(png_path)
def convert_white_to_black(image_path):
    """Recolour every pure-white pixel of the image at *image_path* to black, in place.

    The image is converted to RGB before processing and written back as RGB, so
    any alpha channel present in the original file is dropped.

    Args:
        image_path: Path of the image file to read and then overwrite.
    """
    # Context manager guarantees the source file handle is released before the
    # same path is overwritten below.
    with Image.open(image_path) as img:
        pixels = np.array(img.convert("RGB"))
    # A pixel counts as white only when all three channels are exactly 255.
    white = np.all(pixels == 255, axis=-1)
    pixels[white] = 0
    Image.fromarray(pixels).save(image_path)
def convert_wav_to_spectrograms(input_folder, output_folder):
    """Create a spectrogram PNG for every WAV file directly inside *input_folder*.

    Only regular files whose name ends in ``.wav`` (case-insensitive) are
    processed; subdirectories are not descended into. Files are handled in sorted
    name order, and a progress line is printed after each one.

    Args:
        input_folder: Directory scanned (non-recursively) for WAV files.
        output_folder: Directory that receives one ``<stem>.png`` per WAV file;
            created, including missing parents, if it does not exist.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)
    # os.listdir order is arbitrary; sort for a deterministic processing order.
    for file in sorted(os.listdir(input_folder)):
        wav_file_path = os.path.join(input_folder, file)
        # Accept ".WAV" etc., and skip directories that merely end in ".wav".
        if not file.lower().endswith('.wav') or not os.path.isfile(wav_file_path):
            continue
        create_spectrogram(wav_file_path, output_folder)
        print(f"Converted {file} to spectrogram.")
if __name__ == "__main__":
    # Default folders, resolved relative to the current working directory:
    # WAV files are read from 'dataset' and PNGs are written to 'spectrograms'.
    input_folder, output_folder = 'dataset', 'spectrograms'
    convert_wav_to_spectrograms(input_folder, output_folder)