Update
Browse files
model.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import gc
|
2 |
import tempfile
|
3 |
|
4 |
import numpy as np
|
@@ -70,17 +69,15 @@ class Model:
|
|
70 |
'cuda' if torch.cuda.is_available() else 'cpu')
|
71 |
self.xm = load_model('transmitter', device=self.device)
|
72 |
self.diffusion = diffusion_from_config(load_config('diffusion'))
|
73 |
-
self.
|
74 |
-
self.
|
75 |
|
76 |
def load_model(self, model_name: str) -> None:
|
77 |
assert model_name in ['text300M', 'image300M']
|
78 |
-
if model_name == self.
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
gc.collect()
|
83 |
-
torch.cuda.empty_cache()
|
84 |
|
85 |
def to_glb(self, latent: torch.Tensor) -> str:
|
86 |
ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
|
@@ -109,7 +106,7 @@ class Model:
|
|
109 |
|
110 |
latents = sample_latents(
|
111 |
batch_size=1,
|
112 |
-
model=self.
|
113 |
diffusion=self.diffusion,
|
114 |
guidance_scale=guidance_scale,
|
115 |
model_kwargs=dict(texts=[prompt]),
|
@@ -135,7 +132,7 @@ class Model:
|
|
135 |
image = load_image(image_path)
|
136 |
latents = sample_latents(
|
137 |
batch_size=1,
|
138 |
-
model=self.
|
139 |
diffusion=self.diffusion,
|
140 |
guidance_scale=guidance_scale,
|
141 |
model_kwargs=dict(images=[image]),
|
|
|
|
|
1 |
import tempfile
|
2 |
|
3 |
import numpy as np
|
|
|
69 |
'cuda' if torch.cuda.is_available() else 'cpu')
|
70 |
self.xm = load_model('transmitter', device=self.device)
|
71 |
self.diffusion = diffusion_from_config(load_config('diffusion'))
|
72 |
+
self.model_text = None
|
73 |
+
self.model_image = None
|
74 |
|
75 |
def load_model(self, model_name: str) -> None:
|
76 |
assert model_name in ['text300M', 'image300M']
|
77 |
+
if model_name == 'text300M' and self.model_text is None:
|
78 |
+
self.model_text = load_model(model_name, device=self.device)
|
79 |
+
elif model_name == 'image300M' and self.model_image is None:
|
80 |
+
self.model_image = load_model(model_name, device=self.device)
|
|
|
|
|
81 |
|
82 |
def to_glb(self, latent: torch.Tensor) -> str:
|
83 |
ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
|
|
|
106 |
|
107 |
latents = sample_latents(
|
108 |
batch_size=1,
|
109 |
+
model=self.model_text,
|
110 |
diffusion=self.diffusion,
|
111 |
guidance_scale=guidance_scale,
|
112 |
model_kwargs=dict(texts=[prompt]),
|
|
|
132 |
image = load_image(image_path)
|
133 |
latents = sample_latents(
|
134 |
batch_size=1,
|
135 |
+
model=self.model_image,
|
136 |
diffusion=self.diffusion,
|
137 |
guidance_scale=guidance_scale,
|
138 |
model_kwargs=dict(images=[image]),
|