humtrans / eval.py
hayaton0005's picture
Upload 11 files
c094356 verified
import numpy as np
import pretty_midi as pm
import mir_eval
"""# Evaluation function"""
def extract_midi(midi: pm.PrettyMIDI, program=0): # MIDIデータを読み込んだ PrettyMIDI オブジェクト, MIDIチャンネル(楽器番号)を指定。
intervals = [] # 音符ごとの開始時間と終了時間のペアを格納したNumPy配列
pitches = [] # 音符ごとの音高(MIDIノート番号)のNumPy配列。
pm_notes = midi.instruments[program].notes # programで指定された対象楽器に含まれる全てのノート情報を取得。
"""
例;
instruments = [
Instrument 0 (Piano): [Note(start=0.5, end=1.0, pitch=60), ...],
Instrument 1 (Violin): [Note(start=1.0, end=1.5, pitch=62), ...]
]
"""
# ノートを順番に処理
for note in pm_notes:
intervals.append((note.start, note.end)) # 音符の開始・終了時間のペアを intervals に追加。
pitches.append(note.pitch) # 音符の音高を pitches に追加。
return np.array(intervals), np.array(pitches) # intervals: 2D配列(各行が1つの音符の開始・終了時間を表す。), pitches: 1D配列(各要素が1つの音符の音高(ピッチ)を表す。)
def evaluate_midi(est_midi: pm.PrettyMIDI, ref_midi: pm.PrettyMIDI, program=0):
est_intervals, est_pitches = extract_midi(est_midi, program)
ref_intervals, ref_pitches = extract_midi(ref_midi, program)
# mir_eval ライブラリの transcription モジュールを使って、音符の一致度を評価します。
dict_eval = mir_eval.transcription.evaluate(
ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05)
return dict_eval # dict_eval: 評価結果の辞書。