|
import logging |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from diffusers.utils import check_min_version |
|
from pipeline import LotusGPipeline, LotusDPipeline |
|
from utils.image_utils import colorize_depth_map |
|
from contextlib import nullcontext |
|
|
|
check_min_version('0.28.0.dev0') |
|
|
|
def load_models(task_name, device):
    """Load the Lotus G and D pipelines for a task and move them to ``device``.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            selects the normal checkpoints.
        device: Target passed to each pipeline's ``.to()``.

    Returns:
        Tuple ``(pipe_g, pipe_d)``: both loaded with ``torch.float16`` weights
        and with their progress bars disabled.
    """
    wants_depth = task_name == 'depth'
    model_g = 'jingheya/lotus-depth-g-v1-0' if wants_depth else 'jingheya/lotus-normal-g-v1-0'
    model_d = 'jingheya/lotus-depth-d-v1-1' if wants_depth else 'jingheya/lotus-normal-d-v1-0'

    half_precision = torch.float16
    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=half_precision)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=half_precision)

    # Two passes keep the original call order: both moves first, then both
    # progress-bar updates.
    loaded = (pipe_g, pipe_d)
    for pipeline in loaded:
        pipeline.to(device)
    for pipeline in loaded:
        pipeline.set_progress_bar_config(disable=True)

    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
    return pipe_g, pipe_d
|
|
|
def infer_pipe(pipe, images_batch, task_name, seed, device):
    """Run one Lotus pipeline over a batch of images and post-process the result.

    Args:
        pipe: Pipeline to call. It is invoked with ``rgb_in``, ``prompt``,
            ``num_inference_steps``, ``generator``, ``output_type``,
            ``timesteps`` and ``task_emb``; its result must expose ``.images``.
            ``pipe.device.type`` is read to configure autocast.
        images_batch: Iterable of images offering ``convert('RGB')`` (e.g. PIL
            images). They are stacked into one tensor, so they must all have
            the same size.
        task_name: ``'depth'`` averages each prediction over its last axis and
            passes it to ``colorize_depth_map``; any other value converts each
            prediction to an 8-bit PIL image.
        seed: Integer seed for a ``torch.Generator`` on ``device``, or ``None``
            to pass no generator to the pipeline.
        device: Device for the input tensors and the generator.

    Returns:
        A list with one post-processed output per input image; an empty list
        when ``images_batch`` is empty.
    """
    # Materialise once so generators work, and bail out early on an empty
    # batch: torch.stack() would otherwise raise on an empty list.
    images_batch = list(images_batch)
    if not images_batch:
        return []

    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # NOTE(review): autocast is skipped whenever an MPS backend is available,
    # regardless of which device the pipeline actually lives on -- confirm
    # that this is intended.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    with torch.no_grad(), autocast_ctx:
        # HWC RGB arrays -> one NCHW tensor, rescaled from [0, 255] to [-1, 1].
        images = [np.array(img.convert('RGB')).astype(np.float32) for img in images_batch]
        test_images = torch.stack([torch.from_numpy(img).permute(2, 0, 1) for img in images])
        test_images = test_images / 127.5 - 1.0
        test_images = test_images.to(device).type(torch.float16)

        # Fixed [1, 0] switch, expanded to [sin, cos] features (4 values) and
        # repeated once per image in the batch.
        batch_size = test_images.shape[0]
        task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
        task_emb = task_emb.repeat(batch_size, 1)

        # Single inference step at timestep 999.
        preds = pipe(
            rgb_in=test_images,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images

        if task_name == 'depth':
            # Collapse the last axis to a single map, then colourise it.
            outputs = [colorize_depth_map(p.mean(axis=-1)) for p in preds]
        else:
            outputs = [Image.fromarray((p * 255).astype(np.uint8)) for p in preds]

    return outputs
|
|
|
def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
    """Run inference on ``images_batch`` with the D pipeline and return its outputs.

    ``pipe_g`` is accepted as part of the signature but is not used here; only
    ``pipe_d`` is passed on to ``infer_pipe``.
    """
    return infer_pipe(pipe_d, images_batch, task_name, seed, device)