import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext

check_min_version('0.28.0.dev0')  # fail fast if the installed diffusers is older than this script targets

def load_models(task_name, device):
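    """Load the Lotus generative (g) and discriminative (d) pipelines for
    `task_name` ('depth' selects the depth checkpoints; any other value
    selects surface normals), move both to `device` in fp16, and disable
    their progress bars.
    """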
    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v1-0'
        model_d = 'jingheya/lotus-depth-d-v1-1'
    else:
        model_g = 'jingheya/lotus-normal-g-v1-0'
        model_d = 'jingheya/lotus-normal-d-v1-0'

    dtype = torch.float16
    pipe_g = LotusGPipeline.from_pretrained(
        model_g,
        torch_dtype=dtype,
    )
    pipe_d = LotusDPipeline.from_pretrained(
        model_d,
        torch_dtype=dtype,
    )
    pipe_g.to(device)
    pipe_d.to(device)
    pipe_g.set_progress_bar_config(disable=True)
    pipe_d.set_progress_bar_config(disable=True)
    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
    return pipe_g, pipe_d

def infer_pipe(pipe, images_batch, task_name, seed, device):
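    """Run single-step Lotus inference on a batch of PIL images.

    `seed=None` leaves the generator unseeded; autocast is skipped on MPS,
    where it is not supported. Returns one visualization per input image:
    a colorized depth map for 'depth', otherwise an RGB normal map.
    """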
    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(seed)

    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    with torch.no_grad():
        with autocast_ctx:
            # Convert the PIL images to a batched CHW float tensor; every
            # image in the batch must share the same height and width
            images = [np.array(img.convert('RGB')).astype(np.float32) for img in images_batch]
            test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
            test_images = test_images / 127.5 - 1.0  # map [0, 255] to [-1, 1]
            test_images = test_images.to(device).type(torch.float16)

            # Build the task embedding: a sin/cos encoding of the [1, 0]
            # task switch, repeated once per image in the batch
            batch_size = test_images.shape[0]
            task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
            task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
            task_emb = task_emb.repeat(batch_size, 1)

            # Run inference
            preds = pipe(
                rgb_in=test_images,
                prompt='',
                num_inference_steps=1,
                generator=generator,
                output_type='np',
                timesteps=[999],
                task_emb=task_emb,
            ).images

            # Post-process predictions
            outputs = []
            if task_name == 'depth':
                for p in preds:
                    output_npy = p.mean(axis=-1)  # depth is replicated across channels; collapse to one map
                    output_color = colorize_depth_map(output_npy)
                    outputs.append(output_color)
            else:
                for p in preds:
                    output_color = Image.fromarray((p * 255).astype(np.uint8))  # scale [0, 1] normals to RGB
                    outputs.append(output_color)

    return outputs

def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
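    """Application entry point: run only the discriminative (d) pipeline on
    the batch. `pipe_g` is accepted for call-site symmetry but unused.
    """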
    output_d = infer_pipe(pipe_d, images_batch, task_name, seed, device)
    return output_d
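
# Minimal usage sketch (not part of the original module): wires the helpers
# above together for a single depth inference. The file names below are
# placeholders, and a CUDA GPU is assumed since the pipelines run in fp16.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    device = 'cuda'
    pipe_g, pipe_d = load_models('depth', device)
    image = Image.open('example.jpg')  # placeholder input path
    outputs = lotus([image], task_name='depth', seed=42, device=device,
                    pipe_g=pipe_g, pipe_d=pipe_d)
    outputs[0].save('depth_colored.png')  # assumed to be a PIL image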