File size: 2,881 Bytes
9c0c5c8
 
1dea888
f30bb70
9c0c5c8
 
f30bb70
1dea888
9c0c5c8
1dea888
9c0c5c8
c17b696
9c0c5c8
f30bb70
 
 
9c0c5c8
 
c17b696
 
 
9c0c5c8
 
c17b696
 
9c0c5c8
1dea888
9c0c5c8
1dea888
9c0c5c8
 
 
 
 
 
 
 
c17b696
 
f30bb70
 
 
 
 
1dea888
 
 
c17b696
 
 
 
 
 
 
9c0c5c8
f30bb70
 
 
1dea888
 
c17b696
 
 
 
 
1dea888
 
 
e97d748
 
9c0c5c8
 
 
1dea888
c17b696
 
9c0c5c8
 
 
1dea888
e97d748
9c0c5c8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import re
import io
import logging
import argparse

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, Features, Image, Value

from audiodiffusion.mel import Mel

# Configure the root logger so that warnings and above are reported, then
# create the logger used by this script.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('audio_to_images')


def main(args):
    """Build a dataset of Mel spectrogram images from a directory of audio.

    Walks ``args.input_dir`` recursively for files ending in ``.mp3``,
    ``.wav`` or ``.m4a`` (case-insensitive), loads each one with ``Mel`` and
    converts every slice to an image. Slices whose image bytes are all 255
    are treated as silent and skipped. Each remaining image is PNG-encoded
    and recorded together with its source file and slice index.

    The collected examples are stored as a ``DatasetDict`` with a single
    ``"train"`` split, saved to ``args.output_dir`` and, when
    ``args.push_to_hub`` is set, passed on to ``DatasetDict.push_to_hub``.

    Whatever has been collected so far is still saved when processing is
    interrupted or fails part-way; the original exception then propagates.

    Args:
        args: Namespace providing ``input_dir``, ``output_dir``,
            ``resolution``, ``hop_length`` and ``push_to_hub``.
    """
    mel = Mel(x_res=args.resolution,
              y_res=args.resolution,
              hop_length=args.hop_length)
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(args.input_dir)
        for file_name in file_names
        if re.search(r"\.(mp3|wav|m4a)$", file_name, re.IGNORECASE)
    ]
    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as err:
                # Best effort: one unreadable file must not abort the run.
                # KeyboardInterrupt and SystemExit do not derive from
                # Exception, so they still propagate.
                logger.warning('Could not load %s, skipping: %s', audio_file,
                               err)
                continue
            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                # Internal sanity check on the generated image size.
                assert (image.width == args.resolution
                        and image.height == args.resolution)
                # skip completely silent slices
                pixels = np.frombuffer(image.tobytes(), dtype=np.uint8)
                if np.all(pixels == 255):
                    logger.warning('File %s slice %d is completely silent',
                                   audio_file, slice_index)
                    continue
                with io.BytesIO() as output:
                    image.save(output, format="PNG")
                    png_bytes = output.getvalue()
                examples.append({
                    "image": {
                        "bytes": png_bytes
                    },
                    "audio_file": audio_file,
                    "slice": slice_index,
                })
    finally:
        # Runs on success, failure and Ctrl-C alike so partial results are
        # kept. Deliberately no `return` in this clause: returning from
        # `finally` would silently discard an in-flight exception.
        if examples:
            ds = Dataset.from_pandas(
                pd.DataFrame(examples),
                features=Features({
                    "image": Image(),
                    "audio_file": Value(dtype="string"),
                    "slice": Value(dtype="int16"),
                }),
            )
            dsd = DatasetDict({"train": ds})
            dsd.save_to_disk(args.output_dir)
            if args.push_to_hub:
                dsd.push_to_hub(args.push_to_hub)
        else:
            logger.warning('No valid audio files were found.')


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description=
        "Create dataset of Mel spectrograms from directory of audio files.")
    # Required: without it main() would fail later inside os.walk(None)
    # with an unhelpful TypeError instead of a usage message.
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="directory searched recursively for "
                        "mp3/wav/m4a files")
    parser.add_argument("--output_dir",
                        type=str,
                        default="data",
                        help="directory the dataset is saved to")
    parser.add_argument("--resolution",
                        type=int,
                        default=256,
                        help="width and height of each spectrogram image")
    parser.add_argument("--hop_length",
                        type=int,
                        default=512,
                        help="hop length passed to Mel")
    parser.add_argument("--push_to_hub",
                        type=str,
                        default=None,
                        help="if set, the dataset is also pushed to the hub "
                        "under this name")
    args = parser.parse_args()
    main(args)