File size: 3,003 Bytes
44189a1
 
c71b96e
 
44189a1
 
 
dc78df8
7376db6
dc78df8
44189a1
 
c71b96e
44189a1
dc78df8
c71b96e
44189a1
dc78df8
 
44189a1
dc78df8
7376db6
 
dc78df8
7376db6
693892f
dc78df8
 
 
 
 
c71b96e
 
44189a1
7376db6
44189a1
8c25de0
 
693892f
ca61b7a
693892f
 
ca61b7a
c71b96e
693892f
73b0806
8c25de0
c71b96e
8c25de0
693892f
 
8c25de0
 
 
 
 
 
693892f
44189a1
693892f
8c25de0
693892f
 
8c25de0
693892f
 
44189a1
693892f
44189a1
7376db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693892f
c71b96e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
import spaces  # Import the spaces module for ZeroGPU

# Enforce a minimum diffusers version (0.28.0.dev0) when this module is imported;
# check_min_version presumably raises if the installed diffusers is older.
check_min_version('0.28.0.dev0')

def load_models(task_name, device):
    """Pick the checkpoint ids and dtype for ``task_name`` without loading anything.

    Only identifiers are chosen here. The pipelines themselves are built later
    from these ids (``lotus`` calls ``from_pretrained`` inside its
    GPU-decorated body).

    Args:
        task_name: ``'depth'`` selects the depth checkpoints. Every other value
            falls back to the surface-normal checkpoints.
        device: Not used; accepted so existing call sites keep working.

    Returns:
        Tuple ``(model_g, model_d, dtype)``: the ``-g`` and ``-d`` Hugging Face
        repo ids for the task, and ``torch.float16``.
    """
    half = torch.float16
    if task_name == 'depth':
        return 'jingheya/lotus-depth-g-v1-0', 'jingheya/lotus-depth-d-v1-1', half
    # Anything that is not 'depth' is served by the normal-estimation checkpoints.
    return 'jingheya/lotus-normal-g-v1-0', 'jingheya/lotus-normal-d-v1-0', half

@spaces.GPU
def infer_pipe(pipe, image, task_name, seed, device):
    """Run a single-step Lotus prediction on ``image`` and return a visualisation.

    Args:
        pipe: Loaded Lotus pipeline. It is called with ``rgb_in``, ``prompt``,
            ``num_inference_steps``, ``generator``, ``output_type``,
            ``timesteps`` and ``task_emb``. Its result's ``.images[0]`` is
            presumably a NumPy array (``output_type='np'``).
        image: PIL image; converted to RGB before inference.
        task_name: ``'depth'`` selects depth post-processing. Any other value
            is treated as a surface-normal prediction.
        seed: Integer seed for a ``torch.Generator`` on ``device``, or ``None``
            to let the pipeline run without an explicit generator.
        device: Torch device (or device string) the pipeline lives on.

    Returns:
        For ``'depth'``: whatever ``colorize_depth_map`` returns for the
        channel-averaged prediction. Otherwise: a PIL image of the prediction
        clipped to [0, 1] and scaled to 8-bit.
    """
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # fp16 autocast is only applied on CUDA, and the decision follows the
    # requested device. Previously 'cuda' was hard-coded whenever MPS was absent,
    # which on a CPU run only emitted a "CUDA not available" warning.
    if torch.device(device).type == 'cuda':
        autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)
    else:
        autocast_ctx = nullcontext()

    with torch.no_grad(), autocast_ctx:
        # PIL RGB -> (H, W, 3) float array -> (1, 3, H, W) tensor,
        # rescaled from [0, 255] to [-1, 1] and cast to fp16 on `device`.
        img = np.array(image.convert('RGB')).astype(np.float32)
        test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
        test_image = test_image / 127.5 - 1.0
        test_image = test_image.to(device).type(torch.float16)

        # Task embedding: the fixed pair [1, 0] expanded to
        # [sin(1), sin(0), cos(1), cos(0)], shape (1, 4).
        # NOTE(review): the meaning of [1, 0] is defined by the pipeline — confirm there.
        task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

        # One inference step at timestep 999 with an empty prompt.
        pred = pipe(
            rgb_in=test_image,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images[0]

        if task_name == 'depth':
            # Average over the last axis (presumably channels) to get one depth map.
            output_npy = pred.mean(axis=-1)
            output_color = colorize_depth_map(output_npy)
        else:
            # Clip first: casting out-of-range floats to uint8 would wrap around.
            output_npy = np.clip(pred, 0.0, 1.0)
            output_color = Image.fromarray((output_npy * 255).astype(np.uint8))

    return output_color

@spaces.GPU
def lotus(image, task_name, seed, device, model_g, model_d, dtype):
    """Load the ``model_d`` Lotus pipeline and run it on ``image``.

    The pipeline is loaded with ``from_pretrained`` inside this
    ``spaces.GPU``-decorated function, so it is re-loaded on every call.

    Args:
        image: PIL image to run on.
        task_name: Forwarded to ``infer_pipe`` (``'depth'`` or a normal task).
        seed: Forwarded to ``infer_pipe``; ``None`` means no explicit generator.
        device: Device the pipeline is moved to and inference runs on.
        model_g: Repo id of the ``-g`` checkpoint. Accepted for call-site
            compatibility but not loaded: its prediction was never returned.
        model_d: Repo id loaded via ``LotusDPipeline.from_pretrained``.
        dtype: ``torch_dtype`` passed to ``from_pretrained``.

    Returns:
        The visualisation produced by ``infer_pipe`` for the ``model_d`` pipeline.
    """
    pipe_d = LotusDPipeline.from_pretrained(
        model_d,
        torch_dtype=dtype,
    )
    pipe_d.to(device)
    pipe_d.set_progress_bar_config(disable=True)
    logging.info("Successfully loaded pipeline from %s.", model_d)

    # The G pipeline used to be loaded and moved to `device` here too, but its
    # output was discarded; skipping it avoids a full extra model load per call.
    return infer_pipe(pipe_d, image, task_name, seed, device)