Spaces:
Paused
Paused
Update marigold_depth_estimation_lcm.py
Browse files
marigold_depth_estimation_lcm.py
CHANGED
@@ -283,6 +283,7 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
|
|
283 |
"""
|
284 |
Encode text embedding for empty prompt.
|
285 |
"""
|
|
|
286 |
prompt = ""
|
287 |
text_inputs = self.tokenizer(
|
288 |
prompt,
|
@@ -291,8 +292,11 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
|
|
291 |
truncation=True,
|
292 |
return_tensors="pt",
|
293 |
)
|
|
|
294 |
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
|
|
295 |
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
|
|
296 |
|
297 |
@torch.no_grad()
|
298 |
def single_infer(
|
@@ -358,7 +362,10 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
|
|
358 |
|
359 |
# Batched empty text embedding
|
360 |
if self.empty_text_embed is None:
|
|
|
361 |
self._encode_empty_text()
|
|
|
|
|
362 |
batch_empty_text_embed = self.empty_text_embed.repeat(
|
363 |
(rgb_latent.shape[0], 1, 1)
|
364 |
) # [B, 2, 1024]
|
|
|
283 |
"""
|
284 |
Encode text embedding for empty prompt.
|
285 |
"""
|
286 |
+
print("_encode_empty_text")
|
287 |
prompt = ""
|
288 |
text_inputs = self.tokenizer(
|
289 |
prompt,
|
|
|
292 |
truncation=True,
|
293 |
return_tensors="pt",
|
294 |
)
|
295 |
+
print(f"{self.text_encoder.device=}")
|
296 |
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
297 |
+
print(f"{text_input_ids.device=}")
|
298 |
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
299 |
+
print(f"{self.empty_text_embed.device=}", f"{self.empty_text_embed.dtype=}")
|
300 |
|
301 |
@torch.no_grad()
|
302 |
def single_infer(
|
|
|
362 |
|
363 |
# Batched empty text embedding
|
364 |
if self.empty_text_embed is None:
|
365 |
+
print("self.empty_text_embed is None")
|
366 |
self._encode_empty_text()
|
367 |
+
else:
|
368 |
+
print("self.empty_text_embed is not None")
|
369 |
batch_empty_text_embed = self.empty_text_embed.repeat(
|
370 |
(rgb_latent.shape[0], 1, 1)
|
371 |
) # [B, 2, 1024]
|