File size: 3,008 Bytes
44189a1
 
c71b96e
 
44189a1
 
 
dc78df8
7376db6
dc78df8
44189a1
 
e4ce1e7
 
 
 
 
c71b96e
e4ce1e7
44189a1
dc78df8
c71b96e
44189a1
dc78df8
 
44189a1
dc78df8
e4ce1e7
 
 
 
 
 
 
 
 
 
 
 
 
dc78df8
693892f
dc78df8
 
 
 
 
c71b96e
 
44189a1
7376db6
44189a1
8c25de0
 
693892f
ca61b7a
693892f
 
ca61b7a
c71b96e
693892f
73b0806
8c25de0
c71b96e
8c25de0
693892f
 
8c25de0
 
 
 
 
 
693892f
44189a1
693892f
8c25de0
693892f
 
8c25de0
693892f
 
44189a1
693892f
44189a1
e4ce1e7
 
693892f
c71b96e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
import spaces  # Import the spaces module for ZeroGPU

# Version gate for the installed diffusers package (check_min_version is
# imported from diffusers.utils). Presumably this fails fast at import time on
# older releases -- confirm against the diffusers documentation.
check_min_version('0.28.0.dev0')

# Module-level pipeline handles. Both start as None and are assigned by
# load_models(); lotus() reads pipe_d when it runs inference.
pipe_g = None
pipe_d = None

@spaces.GPU
def load_models(task_name, device):
    """Load both Lotus pipelines for a task and store them in module globals.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            selects the normal checkpoints.
        device: Target passed to each pipeline's ``.to()`` call.

    Side effects:
        Rebinds the module-level ``pipe_g`` and ``pipe_d`` and emits one
        ``logging.info`` record naming the checkpoints that were loaded.
    """
    global pipe_g, pipe_d

    # Checkpoint ids are chosen per pipeline; everything that is not 'depth'
    # is treated as the normal task.
    is_depth = task_name == 'depth'
    model_g = 'jingheya/lotus-depth-g-v1-0' if is_depth else 'jingheya/lotus-normal-g-v1-0'
    model_d = 'jingheya/lotus-depth-d-v1-1' if is_depth else 'jingheya/lotus-normal-d-v1-0'

    # Both pipelines are loaded with half-precision weights.
    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=torch.float16)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=torch.float16)

    # Move both pipelines first, then silence their progress bars.
    for pipeline in (pipe_g, pipe_d):
        pipeline.to(device)
    for pipeline in (pipe_g, pipe_d):
        pipeline.set_progress_bar_config(disable=True)

    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")

def infer_pipe(pipe, image, task_name, seed, device):
    """Run a single-step inference pass of ``pipe`` on ``image``.

    Args:
        pipe: Pipeline callable. It is invoked with ``rgb_in``, ``prompt``,
            ``num_inference_steps``, ``generator``, ``output_type``,
            ``timesteps`` and ``task_emb`` keyword arguments, and its result
            must expose ``.images``.
        image: Input image exposing ``convert('RGB')`` (presumably a PIL
            image -- confirm against the caller).
        task_name: ``'depth'`` averages the prediction over its last axis and
            colorizes it with ``colorize_depth_map``; any other value converts
            the raw prediction to an 8-bit PIL image.
        seed: Integer seed for the random generator, or ``None`` to pass no
            generator to the pipeline.
        device: Device on which the input tensors and generator are created.

    Returns:
        The result of ``colorize_depth_map`` for the depth task, otherwise a
        ``PIL.Image.Image`` built from the prediction scaled by 255.
    """
    generator = (
        None if seed is None
        else torch.Generator(device=device).manual_seed(seed)
    )

    # Decide on autocast from the device actually in use rather than from
    # whether an MPS backend happens to exist on the host: CUDA autocast is
    # only meaningful when the tensors live on a CUDA device. Requesting it
    # elsewhere makes torch disable it with a warning.
    device_type = 'cpu' if device is None else torch.device(device).type
    if device_type == 'cuda':
        autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)
    else:
        autocast_ctx = nullcontext()

    with torch.no_grad(), autocast_ctx:
        # HWC array -> 1xCxHxW tensor, rescaled from [0, 255] to [-1, 1].
        img = np.array(image.convert('RGB')).astype(np.float32)
        test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
        test_image = test_image / 127.5 - 1.0
        test_image = test_image.to(device).type(torch.float16)

        # Fixed [1, 0] switch expanded to a (1, 4) sin/cos embedding; the same
        # value is used for every task_name.
        task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

        # One denoising step at timestep 999, returning a NumPy image.
        pred = pipe(
            rgb_in=test_image,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images[0]

        if task_name == 'depth':
            # Collapse the channel axis to a single map before colorizing.
            output_npy = pred.mean(axis=-1)
            output_color = colorize_depth_map(output_npy)
        else:
            output_npy = pred
            output_color = Image.fromarray((output_npy * 255).astype(np.uint8))

    return output_color

def lotus(image, task_name, seed, device):
    """Run inference on ``image`` with the module-level ``pipe_d`` pipeline.

    Only ``pipe_d`` is used here; ``pipe_g`` is not run by this function.

    Args:
        image: Input image forwarded to ``infer_pipe``.
        task_name: Task selector forwarded to ``infer_pipe``.
        seed: Seed forwarded to ``infer_pipe`` (``None`` for no generator).
        device: Device forwarded to ``infer_pipe``.

    Returns:
        Whatever ``infer_pipe`` returns for the given task.

    Raises:
        RuntimeError: If ``pipe_d`` has not been populated yet.
    """
    # Fail early with an actionable message instead of a late
    # "'NoneType' object is not callable" from inside the inference call.
    if pipe_d is None:
        raise RuntimeError(
            "Lotus pipelines are not loaded; call load_models() before lotus()."
        )
    return infer_pipe(pipe_d, image, task_name, seed, device)