Emaad commited on
Commit
d75590d
1 Parent(s): 9142bc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -36,21 +36,15 @@ class model:
36
  def __init__(self):
37
  self.model = None
38
  self.model_name = None
39
- self.model_dict = {}
40
 
41
  def gradio_demo(self, model_name, sequence_input, image):
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  if self.model_name != model_name:
44
  self.model_name = model_name
45
- del self.model
46
- if self.model_name not in self.model_dict.keys():
47
- model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
48
- model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
49
- hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
50
- hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
51
- self.model_dict.update({self.model_name:[model_ckpt_path, model_config_path]})
52
- else:
53
- model_ckpt_path, model_config_path = self.model_dict[self.model_name]
54
 
55
  # Load model config and set ckpt_path if not provided in config
56
  config = OmegaConf.load(model_config_path)
 
36
  def __init__(self):
37
  self.model = None
38
  self.model_name = None
 
39
 
40
  def gradio_demo(self, model_name, sequence_input, image):
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  if self.model_name != model_name:
43
  self.model_name = model_name
44
+ model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
45
+ model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
46
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
47
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
 
 
 
 
 
48
 
49
  # Load model config and set ckpt_path if not provided in config
50
  config = OmegaConf.load(model_config_path)