haoheliu commited on
Commit
807c6f0
·
1 Parent(s): 1776a12

Update audioldm/pipeline.py

Browse files
Files changed (1) hide show
  1. audioldm/pipeline.py +18 -4
audioldm/pipeline.py CHANGED
@@ -30,7 +30,23 @@ def make_batch_for_text_to_audio(text, batchsize=1):
30
  )
31
  return batch
32
 
33
- def build_model(config=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if(torch.cuda.is_available()):
35
  device = torch.device("cuda:0")
36
  else:
@@ -40,7 +56,7 @@ def build_model(config=None):
40
  assert type(config) is str
41
  config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
42
  else:
43
- config = default_audioldm_config()
44
 
45
  # Use text as condition instead of using waveform during training
46
  config["model"]["params"]["device"] = device
@@ -49,8 +65,6 @@ def build_model(config=None):
49
  # No normalization here
50
  latent_diffusion = LatentDiffusion(**config["model"]["params"])
51
 
52
- resume_from_checkpoint = "./ckpt/ldm_trimmed.ckpt"
53
-
54
  checkpoint = torch.load(resume_from_checkpoint, map_location=device)
55
  latent_diffusion.load_state_dict(checkpoint["state_dict"])
56
 
 
30
  )
31
  return batch
32
 
33
+
34
+
35
+ def build_model(
36
+ ckpt_path=None,
37
+ config=None,
38
+ model_name="audioldm-s-full"
39
+ ):
40
+ print("Load AudioLDM: %s" % model_name)
41
+
42
+ resume_from_checkpoint = "ckpt/%s.ckpt" % model_name
43
+
44
+ # if(ckpt_path is None):
45
+ # ckpt_path = get_metadata()[model_name]["path"]
46
+
47
+ # if(not os.path.exists(ckpt_path)):
48
+ # download_checkpoint(model_name)
49
+
50
  if(torch.cuda.is_available()):
51
  device = torch.device("cuda:0")
52
  else:
 
56
  assert type(config) is str
57
  config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
58
  else:
59
+ config = default_audioldm_config(model_name)
60
 
61
  # Use text as condition instead of using waveform during training
62
  config["model"]["params"]["device"] = device
 
65
  # No normalization here
66
  latent_diffusion = LatentDiffusion(**config["model"]["params"])
67
 
 
 
68
  checkpoint = torch.load(resume_from_checkpoint, map_location=device)
69
  latent_diffusion.load_state_dict(checkpoint["state_dict"])
70