import logging
from contextlib import nullcontext

import numpy as np
import torch
from diffusers.utils import check_min_version
from PIL import Image

from pipeline import LotusDPipeline, LotusGPipeline
from utils.image_utils import colorize_depth_map

# Require diffusers >= 0.28.0.dev0; checked once, at import time.
check_min_version('0.28.0.dev0')

logger = logging.getLogger(__name__)

# Precision used both to load the pipelines and to build their input tensors,
# kept in one place so the two cannot drift apart.
_DTYPE = torch.float16

# task -> (G-pipeline model id, D-pipeline model id).
_MODEL_IDS = {
    'depth': ('jingheya/lotus-depth-g-v1-0', 'jingheya/lotus-depth-d-v1-1'),
    'normal': ('jingheya/lotus-normal-g-v1-0', 'jingheya/lotus-normal-d-v1-0'),
}


def load_models(task_name, device):
    """Load the Lotus G and D pipelines for ``task_name`` onto ``device``.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            falls back to the surface-normal checkpoints.
        device: Torch device (or device string) the pipelines are moved to.

    Returns:
        ``(pipe_g, pipe_d)``: a ``LotusGPipeline`` and a ``LotusDPipeline``,
        both loaded in float16 with their progress bars disabled.

    NOTE(review): ``lotus`` below only uses ``pipe_d``; ``pipe_g`` is still
    loaded onto ``device`` and occupies memory.
    """
    model_g, model_d = _MODEL_IDS['depth' if task_name == 'depth' else 'normal']

    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=_DTYPE)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=_DTYPE)
    for pipe in (pipe_g, pipe_d):
        pipe.to(device)
        pipe.set_progress_bar_config(disable=True)

    logger.info("Successfully loaded pipelines from %s and %s.", model_g, model_d)
    return pipe_g, pipe_d


def _to_model_input(images, device):
    """Stack PIL images into a (B, 3, H, W) ``_DTYPE`` tensor scaled to [-1, 1]."""
    # All images must share one size; np.stack raises ValueError otherwise.
    rgb = np.stack([np.asarray(img.convert('RGB'), dtype=np.float32) for img in images])
    # (B, H, W, 3) -> (B, 3, H, W); .contiguous() keeps a standard NCHW layout.
    batch = torch.from_numpy(rgb).permute(0, 3, 1, 2).contiguous()
    # Map pixel values [0, 255] to [-1, 1].
    batch = batch / 127.5 - 1.0
    return batch.to(device).type(_DTYPE)


def _postprocess(preds, task_name):
    """Turn raw numpy predictions into one visualization per image."""
    if task_name == 'depth':
        # Average over the last axis (presumably channels) to get one depth value
        # per pixel, then colorize it.
        return [colorize_depth_map(p.mean(axis=-1)) for p in preds]
    # Normal maps are scaled by 255 (assumes predictions are nominally in [0, 1]).
    # Clip first: out-of-range floats would otherwise wrap around in the uint8 cast.
    return [Image.fromarray((np.clip(p, 0.0, 1.0) * 255).astype(np.uint8)) for p in preds]


def infer_pipe(pipe, images_batch, task_name, seed, device):
    """Run single-step Lotus inference on a batch of PIL images.

    Args:
        pipe: A loaded Lotus pipeline (e.g. one returned by ``load_models``).
        images_batch: Iterable of PIL images, all of the same size.
        task_name: ``'depth'`` colorizes predictions with ``colorize_depth_map``;
            any other value treats them as normal maps and returns PIL images.
        seed: Integer seed for a reproducible ``torch.Generator``, or ``None``.
        device: Torch device (or device string) the inputs are moved to.

    Returns:
        A list with one visualization per input image, in input order. An empty
        input yields an empty list without calling the pipeline.
    """
    images = list(images_batch)
    if not images:
        # An empty batch cannot be stacked into a tensor; there is nothing to run.
        return []

    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped entirely whenever MPS is available.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    rgb_in = _to_model_input(images, device)
    batch_size = rgb_in.shape[0]

    with torch.no_grad(), autocast_ctx:
        # Fixed task embedding [1, 0] expanded to (sin, cos) features: (B, 4).
        task_emb = torch.tensor([1, 0], device=device, dtype=_DTYPE).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
        task_emb = task_emb.repeat(batch_size, 1)

        # One inference step, pinned to timestep 999.
        preds = pipe(
            rgb_in=rgb_in,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images

    return _postprocess(preds, task_name)


def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
    """Run Lotus inference for ``task_name`` using the D pipeline only.

    ``pipe_g`` is accepted for interface compatibility but is not used.
    Returns the same list ``infer_pipe`` returns for ``pipe_d``.
    """
    # Only the D-pipeline predictions are returned, for depth and normal alike.
    return infer_pipe(pipe_d, images_batch, task_name, seed, device)