Spaces:
Runtime error
Runtime error
import json
import os
import shutil
from datetime import datetime
from functools import wraps
from time import perf_counter
from zipfile import ZipFile, ZIP_DEFLATED

import numpy as np
from joblib import Parallel, delayed
from miditok import Event, MIDILike
from pydub import AudioSegment
from scipy.io.wavfile import write


def writeToFile(path, content):
    """Write ``content`` to ``path``, creating parent directories as needed.

    A dict is serialized as JSON; any other value is converted with ``str()``
    (unless it already is a string) and written as plain text.

    path: destination file path (string or path-like)
    content: dict to dump as JSON, or any object to write as text
    """
    # Create the parent directory for BOTH branches. Skip the call when the
    # path has no directory component, because os.makedirs("") raises.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if isinstance(content, dict):
        with open(path, "w") as json_file:
            json.dump(content, json_file)
        return

    if not isinstance(content, str):
        content = str(content)
    with open(path, "w") as f:
        f.write(content)
# Read the contents of a text file (parsed as JSON when isJSON is True):
def readFromFile(path, isJSON=False):
    """Return the contents of the file at ``path``.

    path: file to open for reading
    isJSON: when true, parse the file as JSON and return the decoded object;
        otherwise return the raw text
    """
    with open(path, "r") as handle:
        if not isJSON:
            return handle.read()
        return json.load(handle)
def chain(input, funcs, *params):
    """Pipe ``input`` through each callable in ``funcs`` and return the result.

    Each callable is first tried as ``func(value, *params)``; if that raises
    ``TypeError`` it is called again as ``func(value)``.

    input: initial value fed to the first callable
    funcs: iterable of callables applied in order
    params: extra positional arguments offered to every callable
    """

    def _apply(step, value):
        # Fall back to the single-argument form when the extended call fails.
        try:
            return step(value, *params)
        except TypeError:
            return step(value)

    current = input
    for step in funcs:
        current = _apply(step, current)
    return current
def to_beat_str(value, beat_res=8):
    """Format ``value`` as a dotted ``"<whole>.<remainder>.<beat_res>"`` string.

    value: number to convert
    beat_res: number of subdivisions used for the remainder part
    """
    # Number of subdivisions contained in `value`, truncated to an integer.
    ticks = int(value * beat_res)
    whole = int(ticks / beat_res)
    remainder = int(ticks % beat_res)
    return f"{whole}.{remainder}.{beat_res}"
def to_base10(beat_str):
    """Convert a dotted ``"<whole>.<remainder>.<base>"`` string to a number.

    The string is split with ``split_dots`` and the result is
    ``whole + remainder / base``.
    """
    parts = split_dots(beat_str)
    whole, remainder, base = parts
    fraction = remainder / base
    return whole + fraction
def split_dots(value):
    """Split ``value`` on ``"."`` and return the pieces as a list of ints."""
    return [int(piece) for piece in value.split(".")]
def compute_list_average(l):
    """Return the arithmetic mean of the values in ``l``.

    Raises ZeroDivisionError when ``l`` is empty.
    """
    total = sum(l)
    count = len(l)
    return total / count
def get_datetime():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
def get_text(event):
    """Return the text token (with a trailing space) for ``event``.

    ``event.type`` selects the token. Instrument, Time-Shift, Note-On and
    Note-Off tokens also embed ``event.value``. Unknown event types yield an
    empty string.
    """
    # Tokens that do not carry a value.
    plain_tokens = {
        "Piece-Start": "PIECE_START ",
        "Track-Start": "TRACK_START ",
        "Track-End": "TRACK_END ",
        "Bar-Start": "BAR_START ",
        "Bar-End": "BAR_END ",
    }
    # Tokens rendered as "<PREFIX>=<value> ".
    valued_prefixes = {
        "Instrument": "INST",
        "Time-Shift": "TIME_SHIFT",
        "Note-On": "NOTE_ON",
        "Note-Off": "NOTE_OFF",
    }

    kind = event.type
    if kind in plain_tokens:
        return plain_tokens[kind]
    if kind in valued_prefixes:
        return f"{valued_prefixes[kind]}={event.value} "
    return ""
def get_event(text, value=None):
    """Build the ``Event`` matching a text token, or return None.

    text: token name such as "NOTE_ON" or "TIME_SHIFT"
    value: value attached to the created event

    "TRACK_START", "TRACK_END" and unrecognized tokens return None.
    "TIME_DELTA" becomes a "Time-Shift" event whose value is
    ``to_beat_str(int(value) / 4)``.
    """
    if text == "TIME_DELTA":
        # Rescale the delta, then render it as a beat string.
        return Event("Time-Shift", to_beat_str(int(value) / 4))

    # Token name -> event type, for tokens whose value is passed through as is.
    event_types = {
        "PIECE_START": "Piece-Start",
        "INST": "Instrument",
        "BAR_START": "Bar-Start",
        "BAR_END": "Bar-End",
        "TIME_SHIFT": "Time-Shift",
        "NOTE_ON": "Note-On",
        "NOTE_OFF": "Note-Off",
    }
    event_type = event_types.get(text)
    if event_type is None:
        return None
    return Event(event_type, value)
# TODO: Make this singleton | |
def get_miditok():
    """Create and return a new ``MIDILike`` tokenizer.

    The tokenizer is built with pitches 0-139 and a beat resolution mapping
    of ``{(0, 400): 8}``. A new instance is created on every call.
    """
    pitches = range(0, 140)  # was (21, 109)
    resolution = {(0, 400): 8}
    tokenizer = MIDILike(pitches, resolution)
    return tokenizer
class WriteTextMidiToFile:  # utils saving to file
    """Save a generated piece and its per-track data to a timestamped JSON file."""

    def __init__(self, generate_midi, output_path):
        """Capture what has to be saved.

        generate_midi: object exposing ``generated_piece`` and ``piece_by_track``
        output_path: directory in which the JSON file is written
        """
        self.generated_midi = generate_midi.generated_piece
        self.output_path = output_path
        self.hyperparameter_and_bars = generate_midi.piece_by_track

    def hashing_seq(self):
        """Set ``current_time`` and derive the timestamped output file name."""
        timestamp = get_datetime()
        self.current_time = timestamp
        self.output_path_filename = f"{self.output_path}/{timestamp}.json"

    def wrapping_seq_hyperparameters_in_dict(self):
        """Return the dict that gets serialized to the output file."""
        return dict(
            generate_midi=self.generated_midi,
            hyperparameters_and_bars=self.hyperparameter_and_bars,
        )

    def text_midi_to_file(self):
        """Write the output dict to disk and return the path of the written file."""
        self.hashing_seq()
        destination = self.output_path_filename
        payload = self.wrapping_seq_hyperparameters_in_dict()
        print(f"Token generate_midi written: {destination}")
        writeToFile(destination, payload)
        return destination
def get_files(directory, extension, recursive=False):
    """
    List the paths of all files in a directory that have a given extension.

    directory: the directory to search, as a Path object
    extension: the file extension to match, as a string without the leading dot
    recursive: when true, subdirectories are searched as well
    """
    pattern = f"*.{extension}"
    search = directory.rglob if recursive else directory.glob
    return list(search(pattern))
def timeit(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged, and its
    metadata (``__name__``, ``__doc__``, ...) is preserved on the wrapper.
    """

    @wraps(func)  # keep the original name/docstring visible on the wrapper
    def wrapper(*args, **kwargs):
        start = perf_counter()
        result = func(*args, **kwargs)
        elapsed = perf_counter() - start
        print(f"{func.__name__} took {elapsed:.2f} seconds to run.")
        return result

    return wrapper
class FileCompressor:
    """Zip and unzip the files of a directory, optionally in parallel.

    input_directory: Path searched for ``.zip`` archives by ``unzip``
    output_directory: Path that receives extracted files and new archives;
        also the directory searched for ``.txt`` files by ``zip``
    n_jobs: value forwarded to joblib's ``Parallel`` as ``n_jobs``
    """

    def __init__(self, input_directory, output_directory, n_jobs=-1):
        self.input_directory = input_directory
        self.output_directory = output_directory
        self.n_jobs = n_jobs

    # File compression and decompression
    def unzip_file(self, file):
        """Extract a single zip archive into the output directory."""
        with ZipFile(file, "r") as zip_ref:
            zip_ref.extractall(self.output_directory)

    def zip_file(self, file):
        """Compress a single file into a new zip archive and delete the original."""
        output_file = self.output_directory / (file.stem + ".zip")
        with ZipFile(output_file, "w") as zip_ref:
            zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
        # Delete the source only once the archive has been closed, so a failure
        # while finalizing the zip cannot cost us the original file.
        file.unlink()

    def unzip(self):
        """Extract every zip archive found in the input directory."""
        files = get_files(self.input_directory, extension="zip")
        Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)

    def zip(self):
        """Compress every text file in the output directory and remove the originals."""
        files = get_files(self.output_directory, extension="txt")
        Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)
def load_jsonl(filepath):
    """Load a jsonl file and return its records as a list.

    Each non-blank line is parsed as one JSON document. Blank lines (for
    example a trailing empty line at the end of the file) are skipped instead
    of raising a decode error.
    """
    with open(filepath, "r") as f:
        return [json.loads(line) for line in f if line.strip()]
def write_mp3(waveform, output_path, bitrate="92k", sample_rate=44100):
    """
    Write a waveform to an mp3 file.

    The waveform is first written as a float32 wav file next to the target
    (same path with a ``.wav`` suffix), converted to mp3, and the wav file is
    then removed - also when the conversion fails.

    waveform: numpy array of the waveform
    output_path: Path object for the output mp3 file
    bitrate: bitrate string passed to the mp3 export (e.g. "92k", "128k")
    sample_rate: sample rate in Hz used for the intermediate wav file
    """
    wav_path = output_path.with_suffix(".wav")
    try:
        # write the wav file
        write(wav_path, sample_rate, waveform.astype(np.float32))
        # compress the wav file as mp3
        AudioSegment.from_wav(wav_path).export(
            output_path, format="mp3", bitrate=bitrate
        )
    finally:
        # Always remove the intermediate wav. missing_ok covers the case where
        # writing the wav itself failed before the file was created.
        wav_path.unlink(missing_ok=True)
def copy_file(input_file, output_dir):
    """Copy ``input_file`` into ``output_dir``, keeping its file name.

    input_file: Path of the file to copy
    output_dir: Path of the destination directory
    """
    destination = output_dir / input_file.name
    shutil.copy(input_file, destination)
def index_has_substring(list, substring):
    """Return the index of the first item containing ``substring``, or -1.

    list: sequence of strings to scan in order
    substring: text to look for inside each item
    """
    matches = (pos for pos, item in enumerate(list) if substring in item)
    return next(matches, -1)