hugo flores garcia commited on
Commit
08c78c6
1 Parent(s): 05d43c6
app.py CHANGED
@@ -105,6 +105,7 @@ def _vamp(
105
  codes, mask,
106
  batch_size=1 if api else 1,
107
  feedback_steps=1,
 
108
  time_stretch_factor=stretch_factor,
109
  return_mask=True,
110
  temperature=sampletemp,
 
105
  codes, mask,
106
  batch_size=1 if api else 1,
107
  feedback_steps=1,
108
+ _sampling_steps=12 if sig.duration <6.0 else 24,
109
  time_stretch_factor=stretch_factor,
110
  return_mask=True,
111
  temperature=sampletemp,
vampnet/__init__.py CHANGED
@@ -55,8 +55,9 @@ def download_finetuned(name):
55
  filenames = ["coarse.pth", "c2f.pth"]
56
  paths = []
57
  for filename in filenames:
58
- path = f"{MODELS_DIR}/{name}/loras/{filename}"
59
  if not Path(path).exists():
 
60
  path = hf_hub_download(
61
  repo_id=repo_id,
62
  filename=filename,
 
55
  filenames = ["coarse.pth", "c2f.pth"]
56
  paths = []
57
  for filename in filenames:
58
+ path = f"{MODELS_DIR}/loras/{name}/{filename}"
59
  if not Path(path).exists():
60
+ print(f"{path} does not exist, downloading")
61
  path = hf_hub_download(
62
  repo_id=repo_id,
63
  filename=filename,
vampnet/interface.py CHANGED
@@ -537,7 +537,7 @@ class Interface(torch.nn.Module):
537
  zv,
538
  mask=mask,
539
  typical_filtering=True,
540
- _sampling_steps=[2],
541
  return_mask=True
542
  )
543
  mask_z = torch.cat(
 
537
  zv,
538
  mask=mask,
539
  typical_filtering=True,
540
+ _sampling_steps=2,
541
  return_mask=True
542
  )
543
  mask_z = torch.cat(
vampnet/modules/transformer.py CHANGED
@@ -595,13 +595,13 @@ class VampNet(at.ml.BaseModel):
595
  self,
596
  codec,
597
  time_steps: int = 300,
598
- _sampling_steps: List[int] = [12],
599
  start_tokens: Optional[torch.Tensor] = None,
600
  temperature: float = 1.0,
601
  mask: Optional[torch.Tensor] = None,
602
  mask_temperature: float = 10.5,
603
  typical_filtering=True,
604
- typical_mass=0.2,
605
  typical_min_tokens=64,
606
  top_p=None,
607
  seed: int = None,
@@ -613,7 +613,7 @@ class VampNet(at.ml.BaseModel):
613
  ):
614
  if seed is not None:
615
  at.util.seed(seed)
616
- sampling_steps = sum(_sampling_steps)
617
  logging.debug(f"beginning generation with {sampling_steps} steps")
618
 
619
  #####################
 
595
  self,
596
  codec,
597
  time_steps: int = 300,
598
+ _sampling_steps: int = 12,
599
  start_tokens: Optional[torch.Tensor] = None,
600
  temperature: float = 1.0,
601
  mask: Optional[torch.Tensor] = None,
602
  mask_temperature: float = 10.5,
603
  typical_filtering=True,
604
+ typical_mass=0.15,
605
  typical_min_tokens=64,
606
  top_p=None,
607
  seed: int = None,
 
613
  ):
614
  if seed is not None:
615
  at.util.seed(seed)
616
+ sampling_steps = _sampling_steps
617
  logging.debug(f"beginning generation with {sampling_steps} steps")
618
 
619
  #####################