# Lotus_Depth / infer.py
# Last change: commit 8c25de0 "Update infer.py" (verified) by ghostsInTheMachine; file size 3.04 kB.
import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
# Version guard: this script targets diffusers >= 0.28.0.dev0. check_min_version
# (from diffusers.utils) is expected to raise when the installed diffusers is
# older, so the failure happens here, when the module runs, instead of later.
check_min_version('0.28.0.dev0')
def load_models(task_name, device, *, dtype=torch.float16):
    """Load the two Lotus pipelines (the ``-g`` and ``-d`` checkpoints) for a task.

    Args:
        task_name: ``'depth'`` loads the depth checkpoints; any other value
            loads the normal-map checkpoints.
        device: Device both pipelines are moved to.
        dtype: Weight dtype forwarded to ``from_pretrained`` as
            ``torch_dtype``. Defaults to ``torch.float16`` (the previously
            hard-coded value); pass ``torch.float32`` on devices where half
            precision is poorly supported (typically CPU).

    Returns:
        ``(pipe_g, pipe_d)``: a ``LotusGPipeline`` and a ``LotusDPipeline``,
        both moved to ``device`` with their progress bars disabled.
    """
    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v1-0'
        model_d = 'jingheya/lotus-depth-d-v1-1'
    else:
        # Any task other than 'depth' falls back to the normal-map checkpoints.
        model_g = 'jingheya/lotus-normal-g-v1-0'
        model_d = 'jingheya/lotus-normal-d-v1-0'

    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=dtype)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=dtype)

    for pipe in (pipe_g, pipe_d):
        pipe.to(device)
        # Silence per-call progress bars; inference is run batch-by-batch.
        pipe.set_progress_bar_config(disable=True)

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logging.info("Successfully loaded pipelines from %s and %s.", model_g, model_d)
    return pipe_g, pipe_d
def infer_pipe(pipe, images_batch, task_name, seed, device):
    """Run a single-step (timestep 999) Lotus prediction on a batch of images.

    Args:
        pipe: A loaded Lotus pipeline (e.g. from ``load_models``). Its
            ``device`` selects the autocast device type and its ``dtype`` is
            the dtype the input batch is cast to.
        images_batch: Iterable of PIL images. They are converted to RGB and
            stacked into one tensor, so all images must share the same size.
        task_name: ``'depth'`` colorizes the channel-mean of each prediction
            with ``colorize_depth_map``; any other value turns each prediction
            into an 8-bit RGB ``PIL.Image`` (normal map).
        seed: Integer seed for a ``torch.Generator`` on ``device``, or ``None``
            to run without an explicit generator.
        device: Device the input batch, task embedding and generator live on.

    Returns:
        A list with one post-processed output per input image, in input
        order; an empty list when ``images_batch`` is empty.
    """
    # Materialize once so any iterable works; each array is channel-last (H, W, 3) uint8.
    rgb_arrays = [np.array(img.convert('RGB')) for img in images_batch]
    # torch.stack() rejects an empty list, so an empty batch short-circuits.
    if not rgb_arrays:
        return []

    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped whenever MPS is *available*, even if the pipe itself
    # is not on MPS. NOTE(review): presumably a workaround for autocast on
    # Apple devices — confirm intent before changing.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    with torch.no_grad(), autocast_ctx:
        # (B, H, W, 3) uint8 -> (B, 3, H, W), mapped from [0, 255] to [-1, 1].
        # Scale in float32, then cast once to the pipeline's dtype instead of
        # doing the arithmetic in float16 on the host.
        batch = torch.stack([torch.from_numpy(arr) for arr in rgb_arrays])
        batch = batch.permute(0, 3, 1, 2).float()
        test_images = (batch / 127.5 - 1.0).to(device=device, dtype=pipe.dtype).contiguous()

        # Fixed task embedding [1, 0] encoded as [sin(1), sin(0), cos(1), cos(0)],
        # repeated once per image in the batch.
        task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
        task_emb = task_emb.repeat(len(test_images), 1)

        preds = pipe(
            rgb_in=test_images,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images

    if task_name == 'depth':
        # Collapse the last (channel) axis into a single-channel depth map before colorizing.
        return [colorize_depth_map(pred.mean(axis=-1)) for pred in preds]

    # Clamp first: NumPy does not saturate out-of-range float -> uint8 casts,
    # so values slightly outside [0, 1] would otherwise wrap to garbage colors.
    return [Image.fromarray((np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)) for pred in preds]
def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
    """Run inference on a batch using only the ``pipe_d`` pipeline.

    ``pipe_g`` is accepted to keep the call signature stable but is never
    used. Whether the outputs are depth or normal maps is decided by
    ``task_name`` (forwarded to ``infer_pipe``) and by the checkpoint
    ``pipe_d`` was loaded from.

    Returns:
        Whatever ``infer_pipe`` returns for ``pipe_d``: one post-processed
        output per input image.
    """
    # This application deliberately ignores pipe_g and returns pipe_d's outputs.
    return infer_pipe(
        pipe=pipe_d,
        images_batch=images_batch,
        task_name=task_name,
        seed=seed,
        device=device,
    )