Spaces:
Runtime error
Runtime error
Commit
·
d53d73c
1
Parent(s):
d9114d9
update
Browse files
app.py
CHANGED
|
@@ -19,6 +19,13 @@ from funcs import (
|
|
| 19 |
get_latent_z,
|
| 20 |
save_videos
|
| 21 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def download_model():
|
| 24 |
REPO_ID = 'Doubiiu/DynamiCrafter_1024'
|
|
@@ -43,7 +50,7 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
|
|
| 43 |
assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
|
| 44 |
model = load_model_checkpoint(model, ckpt_path)
|
| 45 |
model.eval()
|
| 46 |
-
model = model.
|
| 47 |
save_fps = 8
|
| 48 |
|
| 49 |
seed_everything(seed)
|
|
@@ -51,7 +58,10 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
|
|
| 51 |
transforms.Resize(min(resolution)),
|
| 52 |
transforms.CenterCrop(resolution),
|
| 53 |
])
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
| 56 |
start = time.time()
|
| 57 |
if steps > 60:
|
|
@@ -154,4 +164,4 @@ with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
|
|
| 154 |
fn = infer
|
| 155 |
)
|
| 156 |
|
| 157 |
-
dynamicrafter_iface.queue(max_size=12).launch(show_api=True)
|
|
|
|
| 19 |
get_latent_z,
|
| 20 |
save_videos
|
| 21 |
)
|
| 22 |
+
if torch.cuda.is_available():
|
| 23 |
+
device = "cuda"
|
| 24 |
+
elif torch.backends.mps.is_available():
|
| 25 |
+
device = "mps"
|
| 26 |
+
else:
|
| 27 |
+
device = "cpu"
|
| 28 |
+
|
| 29 |
|
| 30 |
def download_model():
|
| 31 |
REPO_ID = 'Doubiiu/DynamiCrafter_1024'
|
|
|
|
| 50 |
assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
|
| 51 |
model = load_model_checkpoint(model, ckpt_path)
|
| 52 |
model.eval()
|
| 53 |
+
model = model.to(device)
|
| 54 |
save_fps = 8
|
| 55 |
|
| 56 |
seed_everything(seed)
|
|
|
|
| 58 |
transforms.Resize(min(resolution)),
|
| 59 |
transforms.CenterCrop(resolution),
|
| 60 |
])
|
| 61 |
+
if device == "cuda":
|
| 62 |
+
torch.cuda.empty_cache()
|
| 63 |
+
elif device == "mps":
|
| 64 |
+
torch.mps.empty_cache()
|
| 65 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
| 66 |
start = time.time()
|
| 67 |
if steps > 60:
|
|
|
|
| 164 |
fn = infer
|
| 165 |
)
|
| 166 |
|
| 167 |
+
dynamicrafter_iface.queue(max_size=12).launch(show_api=True)
|