"""Lotus depth / surface-normal inference helpers for a ZeroGPU Space.

``load_models`` populates the module-level ``pipe_g`` / ``pipe_d`` pipelines
for a task; ``lotus`` then runs single-step inference with ``pipe_d``.
"""

import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
import spaces  # Import the spaces module for ZeroGPU

check_min_version('0.28.0.dev0')

# Global variables to store the models (populated by load_models()).
pipe_g = None
pipe_d = None


@spaces.GPU
def load_models(task_name, device):
    """Load the Lotus G and D pipelines for *task_name* into the module globals.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            selects the surface-normal checkpoints.
        device: Target passed to each pipeline's ``.to()``.

    Side effects:
        Rebinds the module globals ``pipe_g`` and ``pipe_d``.
    """
    # NOTE(review): on ZeroGPU, @spaces.GPU functions may run in a separate
    # worker process; confirm these global assignments are visible to the
    # code that later calls lotus().
    global pipe_g, pipe_d  # Use global variables to store the models

    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v1-0'
        model_d = 'jingheya/lotus-depth-d-v1-1'
    else:
        model_g = 'jingheya/lotus-normal-g-v1-0'
        model_d = 'jingheya/lotus-normal-d-v1-0'

    dtype = torch.float16
    pipe_g = LotusGPipeline.from_pretrained(
        model_g,
        torch_dtype=dtype,
    )
    pipe_d = LotusDPipeline.from_pretrained(
        model_d,
        torch_dtype=dtype,
    )
    pipe_g.to(device)
    pipe_d.to(device)
    pipe_g.set_progress_bar_config(disable=True)
    pipe_d.set_progress_bar_config(disable=True)
    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")


def infer_pipe(pipe, image, task_name, seed, device):
    """Run single-step Lotus inference on one PIL image.

    Args:
        pipe: A loaded Lotus pipeline (e.g. ``pipe_g`` or ``pipe_d``).
        image: PIL image; converted to RGB before inference.
        task_name: ``'depth'`` for depth post-processing; any other value is
            post-processed as a surface-normal map.
        seed: Integer seed for a ``torch.Generator``, or ``None`` to pass no
            generator to the pipeline.
        device: Device the input tensors and the generator are created on.

    Returns:
        For depth, the result of ``colorize_depth_map`` on the channel-averaged
        prediction; otherwise a ``PIL.Image`` built from the prediction scaled
        to 0..255.
    """
    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(seed)

    # On MPS, autocast is skipped (no-op context); otherwise CUDA float16
    # autocast wraps the forward pass.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)

    with torch.no_grad():
        with autocast_ctx:
            # Convert image to tensor: HxWx3 pixels -> 1x3xHxW, scaled from
            # 0..255 to [-1, 1], then cast to float16 on the target device.
            img = np.array(image.convert('RGB')).astype(np.float32)
            test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
            test_image = test_image / 127.5 - 1.0
            test_image = test_image.to(device).type(torch.float16)

            # Create task_emb: [1, 0] expanded to [sin(1), sin(0), cos(1), cos(0)].
            task_emb = torch.tensor(
                [1, 0], device=device, dtype=torch.float16
            ).unsqueeze(0)
            task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

            # Run inference: a single denoising step at timestep 999.
            pred = pipe(
                rgb_in=test_image,
                prompt='',
                num_inference_steps=1,
                generator=generator,
                output_type='np',
                timesteps=[999],
                task_emb=task_emb,
            ).images[0]

            # Post-process prediction
            if task_name == 'depth':
                # Average over the last (presumably channel) axis -> one depth map.
                output_npy = pred.mean(axis=-1)
                output_color = colorize_depth_map(output_npy)
            else:
                output_npy = pred
                # Clip first: out-of-range floats would otherwise wrap around
                # when cast to uint8 (e.g. 1.01 * 255 -> 257 -> 1).
                scaled = np.clip(output_npy, 0.0, 1.0) * 255
                output_color = Image.fromarray(scaled.astype(np.uint8))

    return output_color


def lotus(image, task_name, seed, device):
    """Run inference with the ``pipe_d`` pipeline loaded by ``load_models``.

    Args:
        image: PIL input image.
        task_name: Forwarded to ``infer_pipe`` to select post-processing.
        seed: Integer seed or ``None``.
        device: Device used for the input tensors and generator.

    Returns:
        The post-processed ``pipe_d`` prediction from ``infer_pipe``.

    Raises:
        RuntimeError: if ``load_models`` has not populated ``pipe_d`` yet.
    """
    # The module globals are only read here, so no ``global`` statement is needed.
    if pipe_d is None:
        raise RuntimeError("Lotus models are not loaded; call load_models() first.")
    output_d = infer_pipe(pipe_d, image, task_name, seed, device)
    # Only the pipe_d output is returned (for either task); pipe_g is not used here.
    return output_d