Lotus_Depth / infer.py
ghostsInTheMachine's picture
Update infer.py
693892f verified
raw
history blame
2.77 kB
import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
# Guard against an installed diffusers release older than the one this script
# targets; the check runs as soon as this module is imported.
_MIN_DIFFUSERS_VERSION = '0.28.0.dev0'
check_min_version(_MIN_DIFFUSERS_VERSION)
def load_models(task_name, device, *, dtype=torch.float16):
    """Load the Lotus G and D pipelines for a task and move them to ``device``.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            selects the surface-normal checkpoints.
        device: Target device passed to each pipeline's ``.to(...)``.
        dtype: Weight dtype passed as ``torch_dtype`` to ``from_pretrained``.
            Defaults to ``torch.float16`` (the previous hard-coded value);
            pass ``torch.float32`` for devices where half precision is slow
            or unsupported.

    Returns:
        ``(pipe_g, pipe_d)``: the ``LotusGPipeline`` and ``LotusDPipeline``
        instances, with progress bars disabled.
    """
    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v1-0'
        model_d = 'jingheya/lotus-depth-d-v1-1'
    else:
        # Every non-'depth' task name falls back to the normal-map checkpoints.
        model_g = 'jingheya/lotus-normal-g-v1-0'
        model_d = 'jingheya/lotus-normal-d-v1-0'

    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=dtype)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=dtype)

    for pipe in (pipe_g, pipe_d):
        pipe.to(device)
        # Silence the per-call progress bar.
        pipe.set_progress_bar_config(disable=True)

    logging.info("Successfully loaded pipelines from %s and %s.", model_g, model_d)
    return pipe_g, pipe_d
def _image_to_tensor(image, device, dtype):
    """Convert a PIL image into a (1, 3, H, W) tensor scaled to [-1, 1]."""
    # HWC float32 array in [0, 255] (copied, so it is writable for from_numpy).
    img = np.array(image.convert('RGB'), dtype=np.float32)
    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    tensor = tensor / 127.5 - 1.0
    return tensor.to(device=device, dtype=dtype)


def _postprocess_prediction(pred, task_name):
    """Turn the pipeline's numpy prediction into a displayable result.

    ``pred`` is presumably (H, W, C) with values in [0, 1]; confirm against
    the pipeline's ``output_type='np'`` contract.
    """
    if task_name == 'depth':
        # Average the channel axis into a single depth map, then colourise it.
        return colorize_depth_map(pred.mean(axis=-1))
    # Clip before the uint8 cast so out-of-range values saturate rather than
    # wrapping around (e.g. 1.01 * 255 would otherwise become 1).
    return Image.fromarray((np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8))


def infer_pipe(pipe, image, task_name, seed, device):
    """Run one-step Lotus inference on a single image.

    Args:
        pipe: A loaded Lotus pipeline. Its ``device`` selects the autocast
            device type, and its ``dtype`` is used for the input tensors.
        image: PIL image; converted to RGB before inference.
        task_name: ``'depth'`` to colourise the prediction as a depth map;
            any other value is treated as a normal map.
        seed: Integer seed for a ``torch.Generator``, or ``None`` to pass no
            generator.
        device: Device on which the input tensors and generator are created.

    Returns:
        For ``'depth'``, the result of ``colorize_depth_map`` on the
        channel-averaged prediction; otherwise an 8-bit ``PIL.Image``.
    """
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped whenever MPS is available (original behaviour,
    # presumably due to limited autocast support on MPS — confirm).
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    # Follow the pipeline's weight dtype instead of hard-coding float16, so a
    # float32 pipeline is not fed half-precision inputs.
    dtype = pipe.dtype

    with torch.no_grad(), autocast_ctx:
        test_image = _image_to_tensor(image, device, dtype)

        # Fixed task embedding [1, 0] expanded to [sin, cos] features, shape
        # (1, 4); its meaning is defined by the Lotus pipeline.
        task_emb = torch.tensor([1, 0], device=device, dtype=dtype).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

        # Single denoising step at timestep 999.
        pred = pipe(
            rgb_in=test_image,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images[0]

        return _postprocess_prediction(pred, task_name)
def lotus(image, task_name, seed, device, pipe_g, pipe_d):
    """Run inference with the D pipeline only and return its visualisation.

    ``pipe_g`` is accepted but not used. The output follows ``task_name``: it
    is a depth visualisation only when ``task_name == 'depth'``.
    """
    return infer_pipe(
        pipe_d,
        image=image,
        task_name=task_name,
        seed=seed,
        device=device,
    )