import contextlib
import logging
import os
import zipfile

import requests
import torch

from lib.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    SynthesizerTrnMs768NSFsid,
    SynthesizerTrnMs768NSFsid_nono,
)
from src.config import Config
from src.vc_infer_pipeline import VC

# Only surface fairseq messages at WARNING level or above.
logging.getLogger("fairseq").setLevel(logging.WARNING)

# Bytes written per iteration when streaming a model archive to disk.
_DOWNLOAD_CHUNK_SIZE = 1 << 20


class ModelLoader:
    """Discover, download, and load voice-conversion models from ``./weights``.

    Each model lives in its own subdirectory of ``weights``; ``load`` expects
    it to contain a ``.pth`` checkpoint and, optionally, a ``.index`` file.
    """

    def __init__(self):
        """Scan ``<cwd>/weights`` for model subdirectories.

        Raises:
            ValueError: if ``weights`` contains no subdirectories.
        """
        self.model_root = os.path.join(os.getcwd(), "weights")
        self.config = Config()
        # Every subdirectory of `weights` is treated as one model.
        self.model_list = sorted(
            d
            for d in os.listdir(self.model_root)
            if os.path.isdir(os.path.join(self.model_root, d))
        )
        if not self.model_list:
            raise ValueError("No model found in `weights` folder")
        self.model_name = ""
        self.tgt_sr = None
        self.net_g = None
        self.vc = None
        self.version = None
        self.index_file = None
        self.if_f0 = None
        # Populated by load_hubert().
        self.hubert_model = None

    def _load_from_zip_url(self, url, timeout=60):
        """Download a ``.zip`` model archive and extract it under ``model_root``.

        The archive name is the URL's basename up to and including the first
        ``.zip``; its contents are extracted into ``model_root/<name>`` and the
        downloaded archive is deleted afterwards (also on failure).

        Args:
            url: URL of the archive; must contain ``.zip``.
            timeout: Connect/read timeout in seconds passed to
                ``requests.get`` so a stalled server cannot hang the caller.

        Returns:
            The model name (archive basename without ``.zip``). It is returned
            even when the HTTP status is not 200 (best-effort, as before);
            ``load`` then fails if the model directory does not exist.

        Raises:
            ValueError: if ``url`` does not contain ``.zip``.
        """
        zip_end = url.find(".zip")
        if zip_end == -1:
            raise ValueError(f"Model URL must point to a .zip archive: {url}")
        file_name = os.path.join(
            self.model_root, os.path.basename(url[: zip_end + len(".zip")])
        )
        # The basename ends with the first ".zip" in the URL, so strip it.
        model_name = os.path.basename(file_name)[: -len(".zip")]

        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Could not download model: {model_name}")
                return model_name
            print(f"Extracting Model: {model_name}")
            try:
                # Stream to disk instead of holding the whole archive in memory.
                with open(file_name, "wb") as file:
                    for chunk in response.iter_content(
                        chunk_size=_DOWNLOAD_CHUNK_SIZE
                    ):
                        file.write(chunk)
                with zipfile.ZipFile(file_name, "r") as zip_ref:
                    zip_ref.extractall(os.path.join(self.model_root, model_name))
            finally:
                # Never leave a partial/temporary archive behind.
                with contextlib.suppress(FileNotFoundError):
                    os.remove(file_name)
        return model_name

    def load(self, model_name):
        """Load a model's generator network, VC pipeline, and index file.

        Args:
            model_name: Name of a subdirectory of ``model_root``, or an
                ``http://``/``https://`` URL of a ``.zip`` archive, which is
                downloaded and extracted first.

        Raises:
            ValueError: if no ``.pth`` file is found, or the checkpoint's
                ``version`` is neither ``"v1"`` nor ``"v2"``.
        """
        if model_name.startswith(("http://", "https://")):
            model_name = self._load_from_zip_url(model_name)

        model_dir = os.path.join(self.model_root, model_name)
        # Sorted so the checkpoint/index chosen below is deterministic when a
        # directory holds several (os.listdir order is arbitrary).
        model_files = sorted(os.listdir(model_dir))
        pth_files = [
            os.path.join(model_dir, f) for f in model_files if f.endswith(".pth")
        ]
        if not pth_files:
            raise ValueError(f"No pth file found in {self.model_root}/{model_name}")

        self.model_name = model_name
        pth_path = pth_files[0]
        print(f"Loading {pth_path}, model: {model_name}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources, especially ones downloaded from a URL.
        cpt = torch.load(pth_path, map_location="cpu")
        # The last config entry is used as the target sample rate.
        self.tgt_sr = cpt["config"][-1]
        # Overwrite the speaker count with the speaker-embedding row count.
        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
        self.if_f0 = cpt.get("f0", 1)
        self.version = cpt.get("version", "v1")

        if self.version == "v1":
            synth_f0 = SynthesizerTrnMs256NSFsid
            synth_nono = SynthesizerTrnMs256NSFsid_nono
        elif self.version == "v2":
            synth_f0 = SynthesizerTrnMs768NSFsid
            synth_nono = SynthesizerTrnMs768NSFsid_nono
        else:
            raise ValueError("Unknown version")

        if self.if_f0 == 1:
            self.net_g = synth_f0(*cpt["config"], is_half=self.config.is_half)
        else:
            self.net_g = synth_nono(*cpt["config"])

        # enc_q is removed before loading; strict=False lets the checkpoint's
        # enc_q.* weights (and any other unmatched keys) be ignored.
        del self.net_g.enc_q
        self.net_g.load_state_dict(cpt["weight"], strict=False)
        print("Model loaded")
        self.net_g.eval().to(self.config.device)
        if self.config.is_half:
            self.net_g = self.net_g.half()
        else:
            self.net_g = self.net_g.float()
        self.vc = VC(self.tgt_sr, self.config)

        index_files = [
            os.path.join(model_dir, f) for f in model_files if f.endswith(".index")
        ]
        if not index_files:
            print("No index file found")
            self.index_file = ""
        else:
            self.index_file = index_files[0]
            print(f"Index file found: {self.index_file}")

    def load_hubert(self):
        """Load ``hubert_base.pt`` from ``model_root`` via fairseq.

        Returns:
            The HuBERT model in eval mode on ``config.device``, in half or
            full precision per ``config.is_half``; also stored on
            ``self.hubert_model``.
        """
        # fairseq is imported here rather than at module level, so it is only
        # needed when this method is called.
        from fairseq import checkpoint_utils

        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            [os.path.join(self.model_root, "hubert_base.pt")],
            suffix="",
        )
        self.hubert_model = models[0].to(self.config.device)
        if self.config.is_half:
            self.hubert_model = self.hubert_model.half()
        else:
            self.hubert_model = self.hubert_model.float()
        # Module.eval() switches mode in place and returns the same module.
        return self.hubert_model.eval()