humtrans / utils.py
hayaton0005's picture
Upload 11 files
c094356 verified
from operator import attrgetter
import dataclasses
import numpy as np
import pretty_midi as pm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from nnAudio.features import CQT
import soundfile as sf
from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC
from config import voc_single_track
@dataclasses.dataclass
class Event:
prog: int
onset: bool
pitch: int
class MIDITokenExtractor:
"""
・MIDIデータ(音符、タイミング、ペダル情報など)を抽出し、トークン列に変換する。
・セグメント単位でMIDIデータを分割し、各セグメントをトークンとして表現。
"""
def __init__(self, midi_path, voc_dict, apply_pedal=True):
"""
・MIDIデータを読み込み、必要に応じてサステインペダルを適用する初期化処理を行う。
"""
self.pm = pm.PrettyMIDI(midi_path) # MIDIファイルをPrettyMIDIで読み込む
if apply_pedal:
self.pm_apply_pedal(self.pm) # サステインペダル処理を適用
self.voc_dict = voc_dict # トークンの定義辞書
self.multi_track = "instrument" in voc_dict # マルチトラック対応のフラグ
def pm_apply_pedal(self, pm: pm.PrettyMIDI, program=0):
"""
Apply sustain pedal by stretching the notes in the pm object.
"""
# 1: Record the onset positions of each notes as a dictionary
onset_dict = dict()
for note in pm.instruments[program].notes:
if note.pitch in onset_dict:
onset_dict[note.pitch].append(note.start)
else:
onset_dict[note.pitch] = [note.start]
for k in onset_dict.keys():
onset_dict[k] = np.sort(onset_dict[k])
# 2: Record the pedal on/off state of each time frame
arr_pedal = np.zeros(
round(pm.get_end_time()*FRAME_PER_SEC)+100, dtype=bool)
pedal_on_time = -1
list_pedaloff_time = []
for cc in pm.instruments[program].control_changes:
if cc.number == 64:
if (cc.value > 0) and (pedal_on_time < 0):
pedal_on_time = round(cc.time*FRAME_PER_SEC)
elif (cc.value == 0) and (pedal_on_time >= 0):
pedal_off_time = round(cc.time*FRAME_PER_SEC)
arr_pedal[pedal_on_time:pedal_off_time] = True
list_pedaloff_time.append(cc.time)
pedal_on_time = -1
list_pedaloff_time = np.sort(list_pedaloff_time)
# 3: Stretch the notes (modify note.end)
for note in pm.instruments[program].notes:
# 3-1: Determine whether sustain pedal is on at note.end. If not, do nothing.
# 3-2: Find the next note onset time and next pedal off time after note.end.
# 3-3: Extend note.end till the minimum of next_onset and next_pedaloff.
note_off_frame = round(note.end*FRAME_PER_SEC)
pitch = note.pitch
if arr_pedal[note_off_frame]:
next_onset = np.argwhere(onset_dict[pitch] > note.end)
next_onset = np.inf if len(
next_onset) == 0 else onset_dict[pitch][next_onset[0, 0]]
next_pedaloff = np.argwhere(list_pedaloff_time > note.end)
next_pedaloff = np.inf if len(
next_pedaloff) == 0 else list_pedaloff_time[next_pedaloff[0, 0]]
new_noteoff_time = max(note.end, min(next_onset, next_pedaloff))
new_noteoff_time = min(new_noteoff_time, pm.get_end_time())
note.end = new_noteoff_time
def get_segment_tokens(self, start, end):
"""
Transform a segment of the MIDI file into a sequence of tokens.
"""
dict_event = dict() # a dictionary that maps time to a list of events.
def append_to_dict_event(time, item):
if time in dict_event:
dict_event[time].append(item)
else:
dict_event[time] = [item]
list_events = [] # events section
list_tie_section = [] # tie section
for instrument in self.pm.instruments:
prog = instrument.program
for note in instrument.notes:
note_end = round(note.end * FRAME_PER_SEC) # 音符終了時刻(フレーム単位)
note_start = round(note.start * FRAME_PER_SEC) # 音符開始時刻(フレーム単位)
if (note_end < start) or (note_start >= end):
# セグメント外の音符は無視
continue
if (note_start < start) and (note_end >= start):
# If the note starts before the segment, but ends in the segment
# it is added to the tie section.
# セグメント開始時より前に始まり、セグメント内で終了する音符(ピッチ)をタイセクションに追加
list_tie_section.append(self.voc_dict["note"] + note.pitch)
if note_end < end:
# セグメント内で終了する場合、イベントに終了時刻を記録
append_to_dict_event(
note_end - start, Event(prog, False, note.pitch)
)
continue
assert note_start >= start
# セグメント内で開始
append_to_dict_event(note_start - start, Event(prog, True, note.pitch))
if note_end < end:
# セグメント内で終了
append_to_dict_event(
note_end - start, Event(prog, False, note.pitch)
)
cur_onset = None
cur_prog = -1
for time in sorted(dict_event.keys()): # 現在の相対時間(time)をトークン化し、イベント列に追加
list_events.append(self.voc_dict["time"] + time) # self.voc_dict["time"]はtimeトークンの開始ID(133)。これに相対時間を足す
for event in sorted(dict_event[time], key=attrgetter("pitch", "onset")): # 同一時間内でピッチを昇順に、onset→offsetの順になるようにソートする。
if cur_onset != event.onset:
cur_onset = event.onset # オンセットオフセットが変わった場合に更新
list_events.append(self.voc_dict["onset"] + int(event.onset)) # オンセットオフセットトークンを追加
list_events.append(self.voc_dict["note"] + event.pitch) # 音符トークンを追加
# Concatenate tie section, endtie token, and event section
list_tie_section.append(self.voc_dict["endtie"]) # ID:2を追加
list_events.append(self.voc_dict["eos"]) # ID: 1を追加
tokens = np.concatenate((list_tie_section, list_events)).astype(int)
return tokens
"""## Detokenizer
Transforms a list of MIDI-like token sequences into a MIDI file.
"""
def parse_id(voc_dict: dict, id: int):
"""
トークンIDを解析し、トークンの種類名とその相対IDを返す関数。
"""
keys = voc_dict["keylist"] # トークンの種類リストを取得
token_name = keys[0] # デフォルトの種類を "pad" に設定
# トークンの種類を特定する
for k in keys:
if id < voc_dict[k]: # 現在の種類の開始位置より小さい場合、前の種類が該当
break
token_name = k # 現在の種類名を更新
# 該当種類内での相対IDを計算
token_id = id - voc_dict[token_name]
return token_name, token_id # 種類名と相対IDを返す
def to_second(n):
"""
フレーム数を秒単位に変換する関数。
"""
return n * FRAME_STEP_SIZE_SEC
def find_note(list, n):
"""
タプルリストの最初の要素から指定された値を検索し、そのインデックスを返す関数。
"""
li_elem = [a for a, _ in list] # 最初の要素だけを抽出したリスト
try:
idx = li_elem.index(n) # n が存在する場合、そのインデックスを取得
except ValueError:
return -1 # 存在しない場合は -1 を返す
return idx # 見つかった場合、そのインデックスを返す
def token_seg_list_to_midi(token_seg_list: list):
"""
トークン列リストをMIDIファイルに変換する関数。
"""
# MIDIデータと楽器の初期化
midi_data = pm.PrettyMIDI()
piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
piano = pm.Instrument(program=piano_program)
list_onset = [] # 開始時刻を記録するリスト ※次のセグメント処理を行うときにlist_onsetには、最終音とタイの可能性がある音が記録されている
cur_time = 0 # 前回のセグメントの終了時間
# トークンセグメントごとの処理
for token_seg in token_seg_list:
list_tie = [] # タイ結合された音符
cur_relative_time = -1 # セグメント内の相対時間
cur_onset = -1 # 現在のオンセット状態
tie_end = False # タイ結合が終了したかどうか
for token in token_seg:
# トークンを解析
token_name, token_id = parse_id(voc_single_track, token)
if token_name == "note":
# 音符処理
if not tie_end: # タイ結合の場合
list_tie.append(token_id) # タイのnote番号(相対ID)を追加
elif cur_onset == 1:
list_onset.append((token_id, cur_time + cur_relative_time)) # 開始時刻を記録
elif cur_onset == 0:
# 終了処理
i = find_note(list_onset, token_id)
if i >= 0:
start = list_onset[i][1]
end = cur_time + cur_relative_time
if start < end: # 開始時刻 < 終了時刻の場合のみ追加
new_note = pm.Note(100, token_id, start, end)
piano.notes.append(new_note)
list_onset.pop(i)
elif token_name == "onset":
# オンセット/オフセットの更新
if tie_end:
if token_id == 1:
cur_onset = 1 # 開始
elif token_id == 0:
cur_onset = 0 # 終了
elif token_name == "time":
# 相対時間の更新
if tie_end:
cur_relative_time = to_second(token_id)
elif token_name == "endtie":
# タイ結合終了処理
tie_end = True
for note, start in list_onset: # list_onsetには前回のセグメントで未処理の最終音が含まれる
if note not in list_tie: # list_onsetにあるにも関わらず、list_tieにない場合
if start < cur_time:
new_note = pm.Note(100, note, start, cur_time) # 前回のセグメントの終了時をendとする(end=cur_time)
piano.notes.append(new_note)
list_onset.remove((note, start))
# 現在の時間を更新
cur_time += AUDIO_SEGMENT_SEC
# 楽器をMIDIデータに追加
midi_data.instruments.append(piano)
return midi_data
# 固定長のセグメントに分割する関数
def split_audio_into_segments(y: torch.Tensor, sr: int): # オーディオデータ(テンソル), オーディオのサンプルレート
audio_segment_samples = round(AUDIO_SEGMENT_SEC * sr) # 1セグメントの長さをサンプル数で計算
pad_size = audio_segment_samples - (y.shape[-1] % audio_segment_samples) # セグメントサイズできっちり分割できるようにpadするサイズを計算
y = F.pad(y, (0, pad_size)) # padを追加
assert (y.shape[-1] % audio_segment_samples) == 0 # 割り切れない場合assertをする
n_chunks = y.shape[-1] // audio_segment_samples # セグメント数を計算
# 固定長のセグメントに分割
y_segments = torch.chunk(y, chunks=n_chunks, dim=-1) # torch.chunk: テンソルを指定した数(n_chunks)に分割、dim=-1: サンプル次元(最後の次元)で分割
return torch.stack(y_segments, dim=0) # 分割したセグメントを1つのテンソルにまとめる。dim=0: セグメント数をバッチ次元として結合。
# 形状: (セグメント数, セグメント長)。
# 推論時に作られたpadされたseqから有効な部分を抽出する関数
def unpack_sequence(x: torch.Tensor, eos_id: int=1):
seqs = [] # 各シーケンスを切り出して保存するリストを初期化。
max_length = x.shape[-1] # シーケンスの最大長を取得。whileループの範囲をチェックするために使用。
for seq in x: # テンソル x の各シーケンス(行)を処理
# eosトークンを探す
start_pos = 0
pos = 0
while (pos < max_length) and (seq[pos] != eos_id): # 現在地が最大長を超えない&現在地がeosではない場合
pos += 1
# ループ終了後:pos には、終了トークン(eos_id)の位置(またはシーケンスの末尾)が格納されます。
end_pos = pos+1
seqs.append(seq[start_pos:end_pos]) #開始位置(start_pos)から終了位置(end_pos)までの部分を切り出し、リスト seqs に追加。
return seqs
class LogMelspec(nn.Module):
def __init__(self, sample_rate, n_fft, n_mels, hop_length):
super().__init__()
self.melspec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft, # FFT(高速フーリエ変換)に使用するポイント数
hop_length=hop_length, # ストライド(フレームのシフト幅(サンプル単位))。通常、n_fft // 4 などの値を設定
f_min=20.0, # メルフィルタバンクの最小周波数。20Hz 以上が推奨(人間の聴覚範囲)。
n_mels=n_mels, # メルスペクトログラムの周波数軸方向の次元数(メルバンド数)。
mel_scale="slaney", # メル尺度を計算する方法."slaney": より音響的に意味のあるスケール
norm="slaney", # メルフィルタバンクの正規化方法. "slaney" を指定するとフィルタバンクがエネルギーで正規化
power=1, # 出力スペクトログラムのエネルギースケール, 1: 振幅スペクトログラム。
)
self.eps = 1e-5 # 対数計算時のゼロ除算エラーを防ぐための閾値。メルスペクトログラムの値が 1e-5 未満の場合、この値に置き換える。
def forward(self, x): # 入力: モノラル: (バッチサイズ, サンプル数) or ステレオ: (バッチサイズ, チャンネル数, サンプル数)
spec = self.melspec(x) # 入力波形 x からメルスペクトログラムを計算, 出力:メルスペクトログラム: (バッチサイズ, メルバンド数, フレーム数)
safe_spec = torch.clamp(spec, min=self.eps) # メルスペクトログラムの最小値を self.eps に制限。値が非常に小さい場合でも、対数計算が可能。
log_spec = torch.log(safe_spec) #メルスペクトログラムを対数スケールに変換。
return log_spec # (バッチサイズ, メルバンド数, フレーム数) のテンソル。各値は対数スケールのメルスペクトログラム。
class LogCQT(nn.Module):
def __init__(self, sample_rate, n_bins, hop_length, bins_per_octave):
super().__init__()
self.cqt = CQT(
sr=sample_rate,
hop_length=hop_length,
fmin=32.7, # 低周波数の最小値 (通常は32.7Hz, C1)
fmax=8000, # 最高周波数
n_bins=n_bins,
bins_per_octave=bins_per_octave,
verbose=False
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに載せる
self.eps = 1e-5 # ゼロ除算を防ぐ閾値
def forward(self, x):
x = x.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに送る
cqt_spec = self.cqt(x) # (B, n_bins, time)
safe_spec = torch.clamp(cqt_spec, min=self.eps) # 小さな値をカット
log_cqt = torch.log(safe_spec) # 対数変換
return log_cqt # (B, n_bins, time)
def normalize_rms_torch(audio_tensor, target_rms=0.1):
rms = torch.sqrt(torch.mean(audio_tensor**2)).item()
if rms < 1e-6:
print("音が非常に小さいため、最小値でスケーリングします")
rms = 1e-6
scaling_factor = target_rms / rms
return audio_tensor * scaling_factor
def rms_normalize_wav(input_path, output_path, target_rms=0.1):
waveform, sr = torchaudio.load(input_path)
waveform = waveform.mean(0, keepdim=True) # モノラル化
normalized = normalize_rms_torch(waveform, target_rms)
torchaudio.save(output_path, normalized, sample_rate=sr)