File size: 2,872 Bytes
d4b233b
 
 
44189a1
 
c71b96e
 
44189a1
 
 
dc78df8
 
44189a1
 
c71b96e
44189a1
dc78df8
c71b96e
44189a1
dc78df8
 
44189a1
dc78df8
 
 
 
 
 
 
 
 
 
 
 
 
c71b96e
 
dc78df8
693892f
dc78df8
 
 
 
 
c71b96e
 
44189a1
c71b96e
44189a1
8c25de0
 
693892f
12b0993
693892f
 
12b0993
c71b96e
693892f
73b0806
8c25de0
c71b96e
8c25de0
693892f
 
8c25de0
 
 
 
 
 
693892f
44189a1
693892f
8c25de0
693892f
 
8c25de0
693892f
 
44189a1
693892f
44189a1
693892f
 
c71b96e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Import spaces first (even if not using @spaces.GPU here)
import spaces  # Import spaces to ensure it loads before torch

import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext

# Oldest diffusers release this module is written against; diffusers'
# check_min_version is expected to raise when the installed version is older.
_MIN_DIFFUSERS_VERSION = '0.28.0.dev0'
check_min_version(_MIN_DIFFUSERS_VERSION)

def load_models(task_name, device, dtype=torch.float16):
    """Load the Lotus G and D pipelines for a task and move them to a device.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            selects the surface-normal checkpoints.
        device: Device both pipelines are moved to via ``.to(device)``.
        dtype: Weight dtype passed to ``from_pretrained`` as ``torch_dtype``.
            Defaults to ``torch.float16`` (the previously hard-coded value);
            pass ``torch.float32`` for devices where half precision may be
            poorly supported, such as CPU.

    Returns:
        Tuple ``(pipe_g, pipe_d)`` of ``LotusGPipeline`` and ``LotusDPipeline``
        instances with their progress bars disabled.
    """
    if task_name == 'depth':
        model_g, model_d = 'jingheya/lotus-depth-g-v1-0', 'jingheya/lotus-depth-d-v1-1'
    else:
        model_g, model_d = 'jingheya/lotus-normal-g-v1-0', 'jingheya/lotus-normal-d-v1-0'

    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=dtype)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=dtype)

    # Both pipelines get identical placement and quiet progress output.
    for pipe in (pipe_g, pipe_d):
        pipe.to(device)
        pipe.set_progress_bar_config(disable=True)

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logging.info("Successfully loaded pipelines from %s and %s.", model_g, model_d)
    return pipe_g, pipe_d

def infer_pipe(pipe, image, task_name, seed, device):
    """Run a single-step Lotus pipeline on one image and return a visualization.

    Args:
        pipe: A loaded Lotus pipeline. Its ``device`` selects the autocast
            device type and its ``dtype`` sets the dtype of the tensors fed
            to it (previously always float16).
        image: Input image; ``image.convert('RGB')`` is applied before use
            (presumably a ``PIL.Image``).
        task_name: ``'depth'`` returns a colorized depth map; any other value
            is treated as a surface-normal prediction.
        seed: Integer seed for the sampling generator, or ``None`` to run
            without a seeded generator.
        device: Device on which the generator and input tensors are created.

    Returns:
        For depth, whatever ``colorize_depth_map`` returns for the
        channel-averaged prediction. Otherwise, a ``PIL.Image`` built from
        the prediction scaled to 8-bit.
    """
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # NOTE(review): this disables autocast whenever an MPS backend exists,
    # even if `pipe` itself runs on another device. Presumably it works
    # around limited autocast support on MPS; kept as-is.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    # Feed tensors in the pipeline's own dtype rather than a hard-coded
    # float16, so a pipeline loaded in another dtype still receives matching inputs.
    dtype = pipe.dtype

    with torch.no_grad(), autocast_ctx:
        # HWC pixels in [0, 255] -> NCHW batch of one in [-1, 1]. The
        # arithmetic is done in float32 and cast once, avoiding the extra
        # rounding of doing it in float16.
        rgb = np.array(image.convert('RGB')).astype(np.float32)
        test_image = torch.tensor(rgb).permute(2, 0, 1).unsqueeze(0)
        test_image = (test_image / 127.5 - 1.0).to(device=device, dtype=dtype)

        # Task embedding: sin and cos of the fixed code [1, 0], shape (1, 4).
        task_code = torch.tensor([1, 0], device=device, dtype=dtype).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_code), torch.cos(task_code)], dim=-1)

        pred = pipe(
            rgb_in=test_image,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images[0]

    # Post-processing is pure NumPy, so it runs outside no_grad/autocast.
    if task_name == 'depth':
        # Average over the last axis (assumed to be channels, HWC) to get
        # one depth value per pixel.
        return colorize_depth_map(pred.mean(axis=-1))

    # Clip before the uint8 cast: values outside [0, 1] would otherwise wrap
    # around instead of saturating.
    normal = np.clip(pred, 0.0, 1.0)
    return Image.fromarray((normal * 255).astype(np.uint8))

def lotus(image, task_name, seed, device, pipe_g, pipe_d):
    """Run Lotus inference on ``image`` using only the D pipeline.

    ``pipe_g`` is part of the signature but is not used here. The "_d" in
    ``pipe_d`` refers to the D pipeline, not to depth: the result follows
    whatever task ``task_name`` selects.

    Returns:
        The visualization produced by ``infer_pipe`` for ``pipe_d``.
    """
    return infer_pipe(pipe_d, image, task_name, seed, device)