Emaad commited on
Commit
217e274
1 Parent(s): 120b678

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -12,15 +12,20 @@ class model:
12
  def __init__(self):
13
  self.model = None
14
  self.model_name = None
 
15
 
16
  def gradio_demo(self, model_name, sequence_input, image):
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  if self.model_name != model_name:
19
  self.model_name = model_name
20
- model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
21
- model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
22
- hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
23
- hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
 
 
 
 
24
 
25
  # Load model config and set ckpt_path if not provided in config
26
  config = OmegaConf.load(model_config_path)
 
12
  def __init__(self):
13
  self.model = None
14
  self.model_name = None
15
+ self.model_dict = {}
16
 
17
  def gradio_demo(self, model_name, sequence_input, image):
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  if self.model_name != model_name:
20
  self.model_name = model_name
21
+ if self.model_name not in self.model_dict.keys():
22
+ model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
23
+ model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
24
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
25
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
26
+ self.model_dict.update({self.model_name:[model_ckpt_path, model_config_path]})
27
+ else:
28
+ model_ckpt_path, model_config_path = self.model_dict[self.model_name]
29
 
30
  # Load model config and set ckpt_path if not provided in config
31
  config = OmegaConf.load(model_config_path)