Update app.py
Browse files
app.py
CHANGED
@@ -36,21 +36,15 @@ class model:
|
|
36 |
def __init__(self):
|
37 |
self.model = None
|
38 |
self.model_name = None
|
39 |
-
self.model_dict = {}
|
40 |
|
41 |
def gradio_demo(self, model_name, sequence_input, image):
|
42 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
if self.model_name != model_name:
|
44 |
self.model_name = model_name
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
|
50 |
-
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
|
51 |
-
self.model_dict.update({self.model_name:[model_ckpt_path, model_config_path]})
|
52 |
-
else:
|
53 |
-
model_ckpt_path, model_config_path = self.model_dict[self.model_name]
|
54 |
|
55 |
# Load model config and set ckpt_path if not provided in config
|
56 |
config = OmegaConf.load(model_config_path)
|
|
|
36 |
def __init__(self):
|
37 |
self.model = None
|
38 |
self.model_name = None
|
|
|
39 |
|
40 |
def gradio_demo(self, model_name, sequence_input, image):
|
41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
if self.model_name != model_name:
|
43 |
self.model_name = model_name
|
44 |
+
model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
|
45 |
+
model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
|
46 |
+
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
|
47 |
+
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
# Load model config and set ckpt_path if not provided in config
|
50 |
config = OmegaConf.load(model_config_path)
|