import os

import librosa
import soundfile as sf
import torch
import torchaudio

from train import LitTranscriber
from utils import rms_normalize_wav, token_seg_list_to_midi, unpack_sequence

BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # points to backend/src
PTH_PATH = os.path.join(BASE_DIR, "model.pth")  # plain state_dict checkpoint (.pth)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    args = {
        "n_mels": 128,
        "sample_rate": 16000,
        "n_fft": 1024,
        "hop_length": 128,
    }
    model = LitTranscriber(transcriber_args=args, lr=1e-4, lr_decay=0.99)
    state_dict = torch.load(PTH_PATH, map_location=device)  # load the .pth state dict
    model.load_state_dict(state_dict)
    # Move the model to the same device as the input waveform; without this,
    # inference fails on CUDA because the waveform is sent to `device` below.
    model.to(device)
    model.eval()
    return model

def convert_to_pcm_wav(input_path, output_path):
    # Load with librosa (decoded to PCM automatically), then write out as a PCM WAV.
    y, sr = librosa.load(input_path, sr=16000, mono=True)
    sf.write(output_path, y, sr)

def infer_midi_from_wav(input_wav_path: str) -> str:
    model = load_model()

    # Convert the input to 16 kHz mono PCM WAV, then RMS-normalize it.
    converted_path = os.path.join(BASE_DIR, "converted_input.wav")
    convert_to_pcm_wav(input_wav_path, converted_path)
    normalized_path = os.path.join(BASE_DIR, "tmp_normalized.wav")
    rms_normalize_wav(converted_path, normalized_path, target_rms=0.1)

    waveform, sr = torchaudio.load(normalized_path)
    waveform = waveform.mean(0).to(device)  # downmix to mono
    if sr != model.transcriber.sr:
        waveform = torchaudio.functional.resample(
            waveform, sr, model.transcriber.sr
        ).to(device)

    with torch.no_grad():
        output_tokens = model(waveform)

    # Unpad the token batch and drop the first token of each segment
    # (presumably a start-of-sequence marker) before decoding to MIDI.
    unpadded_tokens = unpack_sequence(output_tokens.cpu().numpy())
    unpadded_tokens = [t[1:] for t in unpadded_tokens]
    est_midi = token_seg_list_to_midi(unpadded_tokens)

    midi_path = os.path.join(BASE_DIR, "output.mid")
    est_midi.write(midi_path)
    print(f"MIDI saved at: {midi_path}")
    return midi_path
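

# Minimal usage sketch, assuming the script is run from backend/src with
# model.pth present; "input.wav" is a hypothetical example path, not a file
# shipped with this repo.
if __name__ == "__main__":
    midi_file = infer_midi_from_wav(os.path.join(BASE_DIR, "input.wav"))
    print(f"Transcription written to {midi_file}")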