import logging
import os
import zipfile

import requests
import torch

from lib.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    SynthesizerTrnMs768NSFsid,
    SynthesizerTrnMs768NSFsid_nono,
)
from src.config import Config
from src.vc_infer_pipeline import VC

logging.getLogger("fairseq").setLevel(logging.WARNING)

class ModelLoader:
    def __init__(self):
        self.model_root = os.path.join(os.getcwd(), "weights")
        self.config = Config()
        # Every sub-directory of `weights` is treated as a selectable model.
        self.model_list = [
            d
            for d in os.listdir(self.model_root)
            if os.path.isdir(os.path.join(self.model_root, d))
        ]
        if len(self.model_list) == 0:
            raise ValueError("No model found in `weights` folder")

        self.model_name = ""
        self.model_list.sort()

        # Populated by load() / load_hubert().
        self.tgt_sr = None
        self.net_g = None
        self.vc = None
        self.version = None
        self.index_file = None
        self.if_f0 = None
        self.hubert_model = None

    def _load_from_zip_url(self, url):
        """Download a zipped model from `url` and extract it into the weights folder."""
        response = requests.get(url)
        file_name = os.path.join(
            self.model_root, os.path.basename(url[: url.index(".zip") + 4])
        )
        model_name = os.path.basename(file_name).replace(".zip", "")
        print(f"Extracting Model: {model_name}")

        if response.status_code == 200:
            with open(file_name, "wb") as file:
                file.write(response.content)

            with zipfile.ZipFile(file_name, "r") as zip_ref:
                zip_ref.extractall(os.path.join(self.model_root, model_name))
            os.remove(file_name)
        else:
            print(f"Could not download model: {model_name}")
        return model_name

    def load(self, model_name):
        """Load the voice model `model_name`, downloading it first if a URL is given."""
        if model_name.startswith(("http://", "https://")):
            model_name = self._load_from_zip_url(model_name)

        pth_files = [
            os.path.join(self.model_root, model_name, f)
            for f in os.listdir(os.path.join(self.model_root, model_name))
            if f.endswith(".pth")
        ]
        if len(pth_files) == 0:
            raise ValueError(f"No pth file found in {self.model_root}/{model_name}")

        self.model_name = model_name
        pth_path = pth_files[0]
        print(f"Loading {pth_path}, model: {model_name}")

        cpt = torch.load(pth_path, map_location="cpu")
        self.tgt_sr = cpt["config"][-1]
        # Patch the speaker count in the config to match the checkpoint's embedding table.
        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
        self.if_f0 = cpt.get("f0", 1)
        self.version = cpt.get("version", "v1")

        if self.version == "v1":
            if self.if_f0 == 1:
                self.net_g = SynthesizerTrnMs256NSFsid(
                    *cpt["config"], is_half=self.config.is_half
                )
            else:
                self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
        elif self.version == "v2":
            if self.if_f0 == 1:
                self.net_g = SynthesizerTrnMs768NSFsid(
                    *cpt["config"], is_half=self.config.is_half
                )
            else:
                self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
        else:
            raise ValueError(f"Unknown version: {self.version}")

        # The posterior encoder is only needed for training, not inference.
        del self.net_g.enc_q
        self.net_g.load_state_dict(cpt["weight"], strict=False)
        print("Model loaded")
        self.net_g.eval().to(self.config.device)

        if self.config.is_half:
            self.net_g = self.net_g.half()
        else:
            self.net_g = self.net_g.float()

        self.vc = VC(self.tgt_sr, self.config)

        index_files = [
            os.path.join(self.model_root, model_name, f)
            for f in os.listdir(os.path.join(self.model_root, model_name))
            if f.endswith(".index")
        ]
        if len(index_files) == 0:
            print("No index file found")
            self.index_file = ""
        else:
            self.index_file = index_files[0]
            print(f"Index file found: {self.index_file}")

    def load_hubert(self):
        """Load the HuBERT content encoder used to extract features for conversion."""
        from fairseq import checkpoint_utils

        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            ["weights/hubert_base.pt"],
            suffix="",
        )
        self.hubert_model = models[0]
        self.hubert_model = self.hubert_model.to(self.config.device)

        if self.config.is_half:
            self.hubert_model = self.hubert_model.half()
        else:
            self.hubert_model = self.hubert_model.float()

        return self.hubert_model.eval()
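

# A minimal usage sketch, not part of the original module: it assumes at least
# one model directory already exists under `weights/` (plus `hubert_base.pt`)
# and that `Config` resolves a valid torch device. Adjust paths to your setup.
if __name__ == "__main__":
    loader = ModelLoader()
    print("Available models:", loader.model_list)
    loader.load(loader.model_list[0])
    loader.load_hubert()
    print(f"Loaded '{loader.model_name}' (sr={loader.tgt_sr}, version={loader.version})")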