import os 
import sys 
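# Force EGL so pyrender can render off-screen (headless) and override the advertised Mesa GL version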
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
# os.system('pip install /home/user/app/pyrender')
# sys.path.append('/home/user/app/pyrender')

import gradio as gr
import spaces
import cv2 
import numpy as np 
import torch 
from ultralytics import YOLO
from pathlib import Path
import argparse
import json
from typing import Dict, Optional

from wilor.models import WiLoR, load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
from wilor.utils.renderer import Renderer, cam_crop_to_full
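# Run on the GPU when one is available, otherwise fall back to the CPU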
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

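# Mesh colour used by the renderer (RGB values in [0, 1])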
LIGHT_PURPLE=(0.25098039,  0.274117647,  0.65882353)

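# Load the pretrained WiLoR reconstruction model together with its configuration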
model, model_cfg = load_wilor(checkpoint_path='./pretrained_models/wilor_final.ckpt', cfg_path='./pretrained_models/model_config.yaml')
# Setup the renderer
renderer = Renderer(model_cfg, faces=model.mano.faces)
model = model.to(device)
model.eval()

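# YOLO hand detector (class 0 = left hand, class 1 = right hand)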
detector = YOLO('./pretrained_models/detector.pt').to(device)

def render_reconstruction(image, conf, IoU_threshold=0.5): 
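    """Detect hands, reconstruct them with WiLoR and overlay the rendered meshes on the input image."""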
    input_img, num_dets, reconstructions = run_wilow_model(image, conf, IoU_threshold=IoU_threshold)
    if num_dets > 0:
    # Render front view
    
        misc_args = dict(
            mesh_base_color=LIGHT_PURPLE,
            scene_bg_color=(1, 1, 1),
            focal_length=reconstructions['focal'],
        )

        cam_view = renderer.render_rgba_multiple(reconstructions['verts'], 
                                                 cam_t=reconstructions['cam_t'], 
                                                 render_res=reconstructions['img_size'], 
                                                 is_right=reconstructions['right'], **misc_args)

        # Overlay image
        
        input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
        input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]

        return input_img_overlay, f'{num_dets} hands detected'  
    else: 
        return input_img, f'{num_dets} hands detected' 

@spaces.GPU()
def run_wilow_model(image, conf, IoU_threshold=0.5):
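    """Run the YOLO hand detector and the WiLoR model on a single RGB image.

    Returns the annotated image, the number of detections and a dict with everything the renderer needs.
    """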
    img_cv2 = image[...,::-1]
    img_vis = image.copy()
    
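    # Detect hands; each detection provides a bounding box, a confidence score and a left/right class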
    detections = detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]
    
    bboxes    = []
    is_right  = []
    for det in detections: 
        Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
        Conf = det.boxes.conf.data.cpu().detach()[0].numpy().reshape(-1).astype(np.float16)
        Side = det.boxes.cls.data.cpu().detach()
        #Bbox[:2] -= np.int32(0.1 * Bbox[:2])
        #Bbox[2:] += np.int32(0.1 * Bbox[ 2:])
        is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())
        
        color = (255*0.208, 255*0.647, 255*0.603) if Side == 0. else (255*1, 255*0.78039, 255*0.2353)
        label = f'L - {Conf[0]:.3f}' if Side==0 else f'R - {Conf[0]:.3f}'

        cv2.rectangle(img_vis, (int(Bbox[0]), int(Bbox[1])), (int(Bbox[2]), int(Bbox[3])), color , 3) 
        (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        cv2.rectangle(img_vis, (int(Bbox[0]), int(Bbox[1]) - 20), (int(Bbox[0]) + w, int(Bbox[1])), color, -1)
        cv2.putText(img_vis, label, (int(Bbox[0]), int(Bbox[1]) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2)
          
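    # Crop each detected hand and reconstruct it with WiLoR, batching the crops through the model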
    if len(bboxes) != 0: 
        boxes = np.stack(bboxes)
        right = np.stack(is_right)
        dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=2.0 )
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)
    
        all_verts = []
        all_cam_t = []
        all_right = []
        all_joints= []
    
        for batch in dataloader: 
            batch = recursive_to(batch, device)
    
            with torch.no_grad():
                out = model(batch) 
                
            print('CUDA AVAILABLE', torch.cuda.is_available())
            print(out['pred_vertices'])
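            # Flip the predicted camera x-offset for left hands, then lift the crop camera to full-image coordinates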
            multiplier    = (2*batch['right']-1)
            pred_cam      = out['pred_cam']
            pred_cam[:,1] = multiplier*pred_cam[:,1]
            box_center    = batch["box_center"].float()
            box_size      = batch["box_size"].float()
            img_size      = batch["img_size"].float()
            scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
            pred_cam_t_full     = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy()

            
            batch_size = batch['img'].shape[0]
            for n in range(batch_size):
                
                verts  = out['pred_vertices'][n].detach().cpu().numpy()
                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
                
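                # Mirror vertices and joints along x for left hands (the network predicts right hands)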
                is_right = batch['right'][n].cpu().numpy()
                verts[:,0] = (2*is_right-1)*verts[:,0]
                joints[:,0] = (2*is_right-1)*joints[:,0]
                
                cam_t = pred_cam_t_full[n]
                
                all_verts.append(verts)
                all_cam_t.append(cam_t)
                all_right.append(is_right)
                all_joints.append(joints)

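        # Pack everything the renderer needs; image size and focal length are shared across detections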
        reconstructions = {'verts': all_verts, 'cam_t': all_cam_t, 'right': all_right, 'img_size': img_size[n], 'focal': scaled_focal_length}
        return img_vis.astype(np.float32)/255.0, len(detections), reconstructions
    else: 
        return img_vis.astype(np.float32)/255.0, len(detections), None       



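# HTML header displayed at the top of the demo page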
header = ('''
<div class="embed_hidden" style="text-align: center;">
    <h1> <b>WiLoR</b>: End-to-end 3D hand localization and reconstruction in-the-wild</h1>
    <h3>
        <a href="https://rolpotamias.github.io" target="_blank" rel="noopener noreferrer">Rolandos Alexandros Potamias</a><sup>1</sup>,
        <a href="" target="_blank" rel="noopener noreferrer">Jinglei Zhang</a><sup>2</sup>,
        <br>
        <a href="https://jiankangdeng.github.io/" target="_blank" rel="noopener noreferrer">Jiankang Deng</a><sup>1</sup>,
        <a href="https://wp.doc.ic.ac.uk/szafeiri/" target="_blank" rel="noopener noreferrer">Stefanos Zafeiriou</a><sup>1</sup>
    </h3>
    <h3>
        <sup>1</sup>Imperial College London;
        <sup>2</sup>Shanghai Jiao Tong University
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/abs/2409.12259'><img src='https://img.shields.io/badge/Arxiv-2409.12259-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a> 
<a href='https://rolpotamias.github.io/pdfs/WiLoR.pdf'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a> 
<a href='https://rolpotamias.github.io/WiLoR/'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a> 
<a href='https://github.com/rolpotamias/WiLoR'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a> 
<a href='https://colab.research.google.com/drive/1bNnYFECmJbbvCNZAKtQcxJGxf0DZppsB?usp=sharing'><img src='https://colab.research.google.com/assets/colab-badge.svg'></a>

</div>
''')


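# Gradio UI: image input and confidence slider on the left, rendered reconstruction and detection count on the right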
with gr.Blocks(title="WiLoR: End-to-end 3D hand localization and reconstruction in-the-wild", css=".gradio-container") as demo:

    gr.Markdown(header)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image", type="numpy")
            threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold')
            #nms = gr.Slider(value=0.5, minimum=0.05, maximum=0.95, step=0.05, label='IoU NMS Threshold')
            submit = gr.Button("Submit", variant="primary")
        
        
        with gr.Column():
            reconstruction = gr.Image(label="Reconstructions", type="numpy")
            hands_detected = gr.Textbox(label="Hands Detected")
    
        submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])

    with gr.Row():
        
        example_images = gr.Examples([
            ['/home/user/app/assets/test6.jpg'], 
            ['/home/user/app/assets/test7.jpg'],
            ['/home/user/app/assets/test8.jpg'],            
            ['/home/user/app/assets/test1.jpg'], 
            ['/home/user/app/assets/test2.png'], 
            ['/home/user/app/assets/test3.jpg'], 
            ['/home/user/app/assets/test4.jpg'],
            ['/home/user/app/assets/test5.jpeg'] 
            ], 
            inputs=input_image)
    
demo.launch(debug=True)