|
import os |
|
import re |
|
import io |
|
import logging |
|
import argparse |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from tqdm.auto import tqdm |
|
from datasets import Dataset, DatasetDict, Features, Image, Value |
|
|
|
from audiodiffusion.mel import Mel |
|
|
|
# Logger used by this script for its own messages.
logger = logging.getLogger('audio_to_images')

# Configure the root logger so that only warnings and above are emitted.
logging.basicConfig(level=logging.WARNING)
|
|
|
|
|
def _save_dataset(examples, output_dir, push_to_hub):
    """Write the collected examples to disk as a one-split ``DatasetDict``.

    Args:
        examples: Non-empty list of dicts with the keys ``"image"``,
            ``"audio_file"`` and ``"slice"``.
        output_dir: Directory handed to ``DatasetDict.save_to_disk``.
        push_to_hub: If truthy, passed on to ``DatasetDict.push_to_hub``
            after the dataset has been saved locally.
    """
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features({
            "image": Image(),
            "audio_file": Value(dtype="string"),
            # NOTE(review): int16 caps the slice index at 32767 -- confirm
            # that no input file can yield more slices than that.
            "slice": Value(dtype="int16"),
        }),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if push_to_hub:
        dsd.push_to_hub(push_to_hub)


def main(args):
    """Turn a directory of audio files into a dataset of Mel spectrogram images.

    Walks ``args.input_dir`` for ``.mp3``/``.wav``/``.m4a`` files (matched
    case-insensitively), converts every slice of every loadable file into a
    PNG-encoded image and stores the result as a ``DatasetDict`` with a
    single ``"train"`` split in ``args.output_dir``.  Files that cannot be
    loaded and slices whose image bytes are all 255 (reported as "completely
    silent") are skipped.

    Whatever has been collected is still saved when processing fails or is
    interrupted part-way; a ``KeyboardInterrupt`` is re-raised afterwards.

    Args:
        args: Namespace providing ``resolution`` (indexable; element 0 is
            used as the image width and element 1 as the height),
            ``hop_length``, ``sample_rate``, ``input_dir``, ``output_dir``
            and ``push_to_hub``.
    """
    width = args.resolution[0]
    height = args.resolution[1]
    mel = Mel(x_res=width,
              y_res=height,
              hop_length=args.hop_length,
              sample_rate=args.sample_rate)
    os.makedirs(args.output_dir, exist_ok=True)

    # Raw string: "\." in a plain literal is an invalid escape sequence.
    audio_pattern = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)
    audio_files = [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(args.input_dir)
        for file_name in file_names
        if audio_pattern.search(file_name)
    ]

    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as err:
                # Best effort: one unreadable file must not abort the run.
                # KeyboardInterrupt/SystemExit are not Exception subclasses
                # and therefore still propagate.
                logger.warning('Skipping %s: %s', audio_file, err)
                continue

            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                # Explicit check instead of `assert`, which is stripped
                # when Python runs with -O.
                if image.width != width or image.height != height:
                    raise ValueError("Wrong resolution")

                pixels = np.frombuffer(image.tobytes(), dtype=np.uint8)
                if np.all(pixels == 255):
                    logger.warning('File %s slice %d is completely silent',
                                   audio_file, slice_index)
                    continue

                with io.BytesIO() as output:
                    image.save(output, format="PNG")
                    png_bytes = output.getvalue()
                examples.append({
                    "image": {
                        "bytes": png_bytes
                    },
                    "audio_file": audio_file,
                    "slice": slice_index,
                })
    except Exception:
        # Keep going to the save step below, but record the full traceback.
        logger.exception('Processing stopped early because of an error.')
    finally:
        # Runs on success, failure and Ctrl-C alike so partial work is kept.
        # There is deliberately no `return` here: returning from `finally`
        # would discard an in-flight KeyboardInterrupt.
        if examples:
            _save_dataset(examples, args.output_dir, args.push_to_hub)
        else:
            logger.warning('No valid audio files were found.')
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, normalise --resolution
    # into a two-element tuple and hand everything over to main().
    cli = argparse.ArgumentParser(
        description="Create dataset of Mel spectrograms from directory of audio files.",
    )
    cli.add_argument("--input_dir", type=str)
    cli.add_argument("--output_dir", type=str, default="data")
    cli.add_argument(
        "--resolution",
        type=str,
        default="256",
        help="Either square resolution or width,height.",
    )
    cli.add_argument("--hop_length", type=int, default=512)
    cli.add_argument("--push_to_hub", type=str, default=None)
    cli.add_argument("--sample_rate", type=int, default=22050)
    args = cli.parse_args()

    # --input_dir has no default, so a missing value shows up as None.
    if args.input_dir is None:
        raise ValueError(
            "You must specify an input directory for the audio files.")

    resolution_text = args.resolution
    try:
        # A single integer is used for both dimensions.
        side = int(resolution_text)
    except ValueError:
        # Otherwise exactly two comma-separated integers are required.
        try:
            args.resolution = tuple(
                int(part) for part in resolution_text.split(","))
            if len(args.resolution) != 2:
                raise ValueError
        except ValueError:
            raise ValueError(
                "Resolution must be a tuple of two integers or a single integer."
            )
    else:
        args.resolution = (side, side)
    assert isinstance(args.resolution, tuple)

    main(args)
|
|