cpu
Browse files- configs/demo.yaml +1 -0
- util.py +7 -6
configs/demo.yaml
CHANGED
@@ -24,6 +24,7 @@ dual_conditioner: False
|
|
24 |
steps: 50
|
25 |
init_step: 0
|
26 |
num_workers: 0
|
|
|
27 |
gpu: 0
|
28 |
max_iter: 100
|
29 |
|
|
|
24 |
steps: 50
|
25 |
init_step: 0
|
26 |
num_workers: 0
|
27 |
+
use_gpu: False
|
28 |
gpu: 0
|
29 |
max_iter: 100
|
30 |
|
util.py
CHANGED
@@ -32,18 +32,19 @@ SD_XL_BASE_RATIOS = {
|
|
32 |
"3.0": (1728, 576),
|
33 |
}
|
34 |
|
35 |
-
def init_model(
|
36 |
|
37 |
-
model_cfg = OmegaConf.load(
|
38 |
-
ckpt =
|
39 |
|
40 |
model = instantiate_from_config(model_cfg.model)
|
41 |
model.init_from_ckpt(ckpt)
|
42 |
|
43 |
-
if
|
44 |
model.train()
|
45 |
else:
|
46 |
-
|
|
|
47 |
model.eval()
|
48 |
model.freeze()
|
49 |
|
@@ -108,7 +109,7 @@ def deep_copy(batch):
|
|
108 |
def prepare_batch(cfgs, batch):
|
109 |
|
110 |
for key in batch:
|
111 |
-
if isinstance(batch[key], torch.Tensor):
|
112 |
batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
|
113 |
|
114 |
if not cfgs.dual_conditioner:
|
|
|
32 |
"3.0": (1728, 576),
|
33 |
}
|
34 |
|
35 |
+
def init_model(cfgs):
|
36 |
|
37 |
+
model_cfg = OmegaConf.load(cfgs.model_cfg_path)
|
38 |
+
ckpt = cfgs.load_ckpt_path
|
39 |
|
40 |
model = instantiate_from_config(model_cfg.model)
|
41 |
model.init_from_ckpt(ckpt)
|
42 |
|
43 |
+
if cfgs.type == "train":
|
44 |
model.train()
|
45 |
else:
|
46 |
+
if cfgs.use_gpu:
|
47 |
+
model.to(torch.device("cuda", index=cfgs.gpu))
|
48 |
model.eval()
|
49 |
model.freeze()
|
50 |
|
|
|
109 |
def prepare_batch(cfgs, batch):
|
110 |
|
111 |
for key in batch:
|
112 |
+
if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu:
|
113 |
batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
|
114 |
|
115 |
if not cfgs.dual_conditioner:
|