File size: 3,875 Bytes
7a3883a
94f04b7
 
abe2204
94f04b7
 
 
 
 
 
 
afe246e
94f04b7
 
 
 
 
afe246e
94f04b7
 
 
 
 
 
 
 
 
afe246e
94f04b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afe246e
 
 
 
7a3883a
afe246e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f04b7
 
 
 
 
 
abe2204
 
94f04b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Part of the code is from: fashn-ai/sapiens-body-part-segmentation
import os

import gradio as gr
import numpy as np
import spaces
import torch
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms
from utils.vis_utils import get_palette, visualize_mask_with_overlay
from config import SAPIENS_LITE_MODELS_PATH

# GPUs with compute capability major version >= 8 (Ampere and newer) support
# TF32 tensor-core math; opt in for cuDNN convolutions and CUDA matmuls when
# device 0 qualifies.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

# Directory that load_model joins checkpoint file names onto.
CHECKPOINTS_DIR = "checkpoints"

def load_model(checkpoint_name: str, device: str = "cuda"):
    """Load a TorchScript checkpoint and prepare it for inference.

    Args:
        checkpoint_name: Key into the ``CHECKPOINTS`` mapping; the mapped value
            is a file name resolved relative to ``CHECKPOINTS_DIR``.
        device: Device the model is loaded onto. Defaults to ``"cuda"``, which
            was previously hard-coded.

    Returns:
        The TorchScript module, switched to eval mode and placed on ``device``.

    Raises:
        KeyError: If ``checkpoint_name`` is not a key of ``CHECKPOINTS``.
    """
    # NOTE(review): CHECKPOINTS is not defined anywhere in this file; it must be
    # provided before this function can be called.
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
    # Without map_location, TorchScript restores tensors onto the device they
    # were saved from; remap them straight onto the requested device instead.
    model = torch.jit.load(checkpoint_path, map_location=device)
    model.eval()
    model.to(device)
    return model


#MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}

@torch.inference_mode()
def run_model(model, input_tensor, height, width):
    """Run *model* on *input_tensor* and return per-pixel class indices.

    The model output is bilinearly resized to ``(height, width)`` first, and the
    index of the maximum along dimension 1 is taken afterwards, so the labels
    come out at the requested resolution.

    Args:
        model: Callable applied to ``input_tensor``; its output is assumed to be
            a 4-D ``(N, C, H, W)`` score tensor — TODO confirm.
        input_tensor: Batch passed unchanged to ``model``.
        height: Target output height in pixels.
        width: Target output width in pixels.

    Returns:
        Tensor of arg-max indices over dim 1 of the resized output.
    """
    logits = model(input_tensor)
    resized = torch.nn.functional.interpolate(
        logits,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    return torch.max(resized, dim=1).indices


# Input preprocessing applied before inference:
#   1. resize to (1024, 768) — torchvision's Resize takes (height, width);
#   2. convert the PIL image to a float tensor;
#   3. normalise each of the three channels with the standard ImageNet mean/std.
transform_fn = transforms.Compose(
    [
        transforms.Resize(size=(1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

@spaces.GPU
def segment(image: Image.Image, model_name: str) -> Image.Image:
    """Segment *image* with the selected model and return a mask overlay.

    Args:
        image: Picture from the Gradio image input; ``None`` when Run is pressed
            before an image has been uploaded.
        model_name: Key selecting the model in ``MODELS``.

    Returns:
        The RGB-converted input combined with the predicted mask by
        ``visualize_mask_with_overlay`` (called with ``alpha=0.5``).

    Raises:
        gr.Error: If no image or no model is selected, so the UI shows a
            readable message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not model_name:
        raise gr.Error("Please select a model version.")

    # transform_fn normalises exactly three channels, but PIL images are not
    # guaranteed to be RGB (e.g. RGBA, grayscale or palette PNGs).
    rgb_image = image.convert("RGB")
    input_tensor = transform_fn(rgb_image).unsqueeze(0).to("cuda")

    # NOTE(review): MODELS and LABELS_TO_IDS have no active definition in this
    # file (the MODELS dict comprehension is commented out), so this raises
    # NameError until they are restored.
    model = MODELS[model_name]
    preds = run_model(model, input_tensor, height=rgb_image.height, width=rgb_image.width)

    # Drop the batch dimension and turn the class indices into an 8-bit image.
    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    return visualize_mask_with_overlay(rgb_image, mask_image, LABELS_TO_IDS, alpha=0.5)


def update_model_choices(task):
    """Build the model-version dropdown for the selected task.

    Args:
        task: Value of the task radio. It is lower-cased before the lookup in
            ``SAPIENS_LITE_MODELS_PATH`` and may be ``None`` when nothing is
            selected.

    Returns:
        A ``gr.Dropdown`` listing the model names for the task, with the first
        one preselected, or an empty dropdown with no selection when the task is
        missing or has no models.
    """
    # A missing or unknown task yields an empty dropdown instead of raising an
    # AttributeError/KeyError inside the Gradio event handler.
    models_for_task = SAPIENS_LITE_MODELS_PATH.get(task.lower(), {}) if task else {}
    model_choices = list(models_for_task.keys())
    return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)

with gr.Blocks() as demo:
    gr.Markdown("# Sapiens Arena 🤸🏽‍♂️ - WIP devmode- Not yet available")
    with gr.Tabs():
        with gr.TabItem('Image'):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png")
                    # `choices` is gr.Radio's first parameter; it used to be given twice
                    # (a hard-coded positional list plus the config keys), so only one
                    # list could ever apply. Derive the options from the config so every
                    # value is a valid key, showing capitalised labels (e.g. "Seg") while
                    # passing the raw key. Preselect "seg" to match the model dropdown
                    # below, which starts out listing the "seg" models.
                    select_task = gr.Radio(
                        choices=[(task.capitalize(), task) for task in SAPIENS_LITE_MODELS_PATH.keys()],
                        value="seg",
                        label="Task",
                        info="Choose the task to perform",
                    )
                    model_name = gr.Dropdown(
                        label="Model Version",
                        choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
                        value="0.3B",
                    )

                    # example_model = gr.Examples(
                    #     inputs=input_image,
                    #     examples_per_page=10,
                    #     examples=[
                    #         os.path.join(ASSETS_DIR, "examples", img)
                    #         for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
                    #     ],
                    # )
                with gr.Column():
                    result_image = gr.Image(label="Segmentation Result", format="png")
                    run_button = gr.Button("Run")

                    #gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")

        with gr.TabItem('Video'):
            gr.Markdown("In construction")

    # Refresh the model-version list whenever a different task is picked.
    select_task.change(fn=update_model_choices, inputs=select_task, outputs=model_name)

    # NOTE(review): Run always calls the segmentation handler, regardless of the
    # selected task.
    run_button.click(
        fn=segment,
        inputs=[input_image, model_name],
        outputs=[result_image],
    )


if __name__ == "__main__":
    demo.launch(share=False)