TirthMusicGen_Rebalanced / hf_loading.py
adefossez's picture
Initial commit
5238467
raw
history blame
2.03 kB
"""Utility for loading the models from HF."""
import os
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download, login
import torch
from audiocraft.models import builders, MusicGen
MODEL_CHECKPOINTS_MAP = {
"small": "facebook/musicgen-small",
"medium": "facebook/musicgen-medium",
"large": "facebook/musicgen-large",
"melody": "facebook/musicgen-melody",
}
login(os.environ['ACCESS_TOKEN'])
def _get_state_dict(file_or_url: tp.Union[Path, str],
filename="state_dict.bin", device='cpu'):
# Return the state dict either from a file or url
print("loading", file_or_url, filename)
file_or_url = str(file_or_url)
assert isinstance(file_or_url, str)
return torch.load(
hf_hub_download(repo_id=file_or_url, filename=filename), map_location=device)
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
cfg = OmegaConf.create(pkg['xp.cfg'])
cfg.device = str(device)
model = builders.get_compression_model(cfg)
model.load_state_dict(pkg['best_state'])
model.eval()
model.cfg = cfg
return model
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
pkg = _get_state_dict(file_or_url)
cfg = OmegaConf.create(pkg['xp.cfg'])
cfg.device = str(device)
if cfg.device == 'cpu':
cfg.transformer_lm.memory_efficient = False
cfg.transformer_lm.custom = True
cfg.dtype = 'float32'
else:
cfg.dtype = 'float16'
model = builders.get_lm_model(cfg)
model.load_state_dict(pkg['best_state'])
model.eval()
model.cfg = cfg
return model
def get_pretrained(name: str = 'small', device='cuda'):
model_id = MODEL_CHECKPOINTS_MAP[name]
compression_model = load_compression_model(model_id, device=device)
lm = load_lm_model(model_id, device=device)
return MusicGen(name, compression_model, lm)