import os
import sys
import torch
import json
import numpy as np
from omegaconf import OmegaConf

from codeclm.trainer.codec_song_pl import CodecLM_PL
from codeclm.models import CodecLM
from codeclm.models import builders

from separator import Separator

class LeVoInference(torch.nn.Module):
    def __init__(self, ckpt_path):
        super().__init__()
        torch.backends.cudnn.enabled = False

        # Resolvers referenced by the checkpoint's OmegaConf config files.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x))
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
        OmegaConf.register_new_resolver("get_fname", lambda: 'default')
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        self.pt_path = os.path.join(ckpt_path, 'model.pt')

        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # Default generation parameters; callers may override any of them
        # via the `params` argument of forward().
        self.default_params = dict(
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )

    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params=dict()):
        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            # A reference track was supplied: separate it into vocal and BGM
            # stems, then encode each stem into codec tokens for conditioning.
            separator = Separator()
            audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
            audio_tokenizer = audio_tokenizer.eval().cuda()
            seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
            seperate_tokenizer = seperate_tokenizer.eval().cuda()
            pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
            pmt_wav = pmt_wav.cuda()
            vocal_wav = vocal_wav.cuda()
            bgm_wav = bgm_wav.cuda()
            pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
            vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
            melody_is_wav = False
            # Release the tokenizers before the language model is loaded.
            del audio_tokenizer
            del seperate_tokenizer
            del separator
        elif genre is not None and auto_prompt_path is not None:
            # No reference audio: pick a pre-computed prompt from the
            # auto-prompt bank, either at random ("Auto") or for the
            # requested genre.
            auto_prompt = torch.load(auto_prompt_path)
            merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
            if genre == "Auto":
                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
            else:
                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
            # Split the stored token tensor into its three streams:
            # prompt mix, vocals, and BGM.
            pmt_wav = prompt_token[:, [0], :]
            vocal_wav = prompt_token[:, [1], :]
            bgm_wav = prompt_token[:, [2], :]
            melody_is_wav = False
        else:
            # No conditioning at all: generate from the lyric/description alone.
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True
        # Stage 1: load the language model and sample song tokens.
        model_light = CodecLM_PL(self.cfg, self.pt_path)
        model_light = model_light.eval()
        model_light.audiolm.cfg = self.cfg
        model = CodecLM(
            name="tmp",
            lm=model_light.audiolm,
            audiotokenizer=None,
            max_duration=self.max_duration,
            seperate_tokenizer=None,
        )
        del model_light
        model.lm = model.lm.cuda().to(torch.float16)

        params = {**self.default_params, **params}
        model.set_generation_params(**params)
        generate_inp = {
            'lyrics': [lyric.replace("  ", " ")],  # collapse double spaces
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = model.generate(**generate_inp, return_tokens=True)

        # Release the language model before loading the detokenizer.
        del model
        torch.cuda.empty_cache()
        # Stage 2: reload the separate tokenizer and decode the tokens
        # back into audio.
        seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
        seperate_tokenizer = seperate_tokenizer.eval().cuda()
        model = CodecLM(
            name="tmp",
            lm=None,
            audiotokenizer=None,
            max_duration=self.max_duration,
            seperate_tokenizer=seperate_tokenizer,
        )

        # Truncate overly long token sequences before decoding.
        if tokens.shape[-1] > 3000:
            tokens = tokens[..., :3000]

        with torch.no_grad():
            if melody_is_wav:
                wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
            else:
                wav_seperate = model.generate_audio(tokens)

        del seperate_tokenizer
        del model
        torch.cuda.empty_cache()

        return wav_seperate[0]
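

# Usage sketch, not part of the original file. The checkpoint directory,
# prompt-bank path, lyric tag format, and 48 kHz output sample rate below are
# all assumptions; substitute whatever your checkpoint actually ships with.
if __name__ == "__main__":
    import torchaudio

    # Directory expected to contain config.yaml and model.pt (assumed path).
    model = LeVoInference("ckpt/songgeneration_base")

    # Structure-tagged lyric string; the tags here are placeholders.
    lyric = "[verse] ... [chorus] ..."

    # Generate with a genre-selected prompt drawn from an assumed prompt bank.
    wav = model(lyric, genre="Pop", auto_prompt_path="ckpt/prompt.pt")

    # wav is a waveform tensor; the sample rate is an assumption, not read
    # from the config.
    torchaudio.save("output.flac", wav.cpu().float(), 48000)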