mrfakename commited on
Commit
96269dd
·
verified ·
1 Parent(s): a8a1e8e

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

model/trainer.py CHANGED
@@ -140,7 +140,7 @@ class Trainer:
140
  else:
141
  latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
 
145
  if self.is_main:
146
  self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
 
140
  else:
141
  latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
144
 
145
  if self.is_main:
146
  self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
model/utils.py CHANGED
@@ -509,7 +509,7 @@ def run_sim(args):
509
  device = f"cuda:{rank}"
510
 
511
  model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
- state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
513
  model.load_state_dict(state_dict['model'], strict=False)
514
 
515
  use_gpu=True if torch.cuda.is_available() else False
@@ -565,7 +565,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
565
  from safetensors.torch import load_file
566
  checkpoint = load_file(ckpt_path, device=device)
567
  else:
568
- checkpoint = torch.load(ckpt_path, map_location=device)
569
 
570
  if use_ema == True:
571
  ema_model = EMA(model, include_online_model = False).to(device)
 
509
  device = f"cuda:{rank}"
510
 
511
  model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
+ state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
  model.load_state_dict(state_dict['model'], strict=False)
514
 
515
  use_gpu=True if torch.cuda.is_available() else False
 
565
  from safetensors.torch import load_file
566
  checkpoint = load_file(ckpt_path, device=device)
567
  else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
 
570
  if use_ema == True:
571
  ema_model = EMA(model, include_online_model = False).to(device)
scripts/eval_infer_batch.py CHANGED
@@ -127,7 +127,7 @@ local = False
127
  if local:
128
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
- state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
131
  vocos.load_state_dict(state_dict)
132
  vocos.eval()
133
  else:
 
127
  if local:
128
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
131
  vocos.load_state_dict(state_dict)
132
  vocos.eval()
133
  else:
speech_edit.py CHANGED
@@ -85,8 +85,9 @@ local = False
85
  if local:
86
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
- state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
89
  vocos.load_state_dict(state_dict)
 
90
  vocos.eval()
91
  else:
92
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
85
  if local:
86
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
89
  vocos.load_state_dict(state_dict)
90
+
91
  vocos.eval()
92
  else:
93
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")