import gradio as gr
import os
import numpy as np
import trimesh as tm
from src.model import DinoV2
from src.shape_model import CSE
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import hashlib

def image_hash(image):
    """Generate a hash for an image."""
    image_bytes = image.tobytes()
    hash_function = hashlib.sha256()
    hash_function.update(image_bytes)
    return hash_function.hexdigest()
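
# The digest above is used as a key into `cached_features` (defined below), so repeated
# clicks on the same uploaded image reuse the already-computed image features instead of
# re-running the encoder. Minimal illustration (`compute_feats` is just a stand-in for
# the encoder call made later in `process_mesh`):
#
#   key = image_hash(pil_image)
#   if key not in cached_features[class_name]:
#       cached_features[class_name][key] = compute_feats(pil_image)
#   feats = cached_features[class_name][key]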


# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')  # the demo runs on CPU; switch to the line above to use a GPU when available
# Load, for each supported class, the image encoder and the CSE shape model
models = {}
for class_name in ['bear', 'horse', 'elephant']:
    print(f'Loading model weights for {class_name}')
    models[class_name] = {
        'image_encoder': DinoV2(16),
        'cse': CSE(class_name=class_name, num_basis=64, device=device)
    }
    models[class_name]['image_encoder'].load_state_dict(torch.load(f'./models/weights/{class_name}.pth', map_location=device))
    models[class_name]['cse'].load_state_dict(torch.load(f'./models/weights/{class_name}_cse.pth', map_location=device))
    models[class_name]['cse'].functional_basis = torch.load(f'./models/weights/{class_name}_lbo.pth', map_location=device)
    
    models[class_name]['image_encoder'] = models[class_name]['image_encoder'].to(device)
    models[class_name]['cse'] = models[class_name]['cse'].to(device)
    models[class_name]['cse'].functional_basis = models[class_name]['cse'].functional_basis.to(device)
    models[class_name]['cse'].weight_matrix = models[class_name]['cse'].weight_matrix.to(device)

    # Precompute the per-vertex embedding features of the canonical 3D shape
    models[class_name]['shape_feats'] = models[class_name]['cse']().to(device)

# Preprocessing: resize to the 224x224 network input size and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

cached_features = {'bear': {}, 'horse': {}, 'elephant': {}}

text_description = """
# Demo for SHIC: Shape-Image Correspondences with no Keypoint Supervision (ECCV 2024)
Project website: https://www.robots.ox.ac.uk/~vgg/research/shic/

- **Step 1:** Select a class (it defaults to 'bear')
- **Step 2:** Upload an image of an animal from that class (or pick one of the provided examples)
- **Step 3:** Click on the image (somewhere on the object) to see the image-to-shape correspondences. You can keep clicking on the input image to see new correspondences.

Notes:
- You can click and drag to rotate the 3D shape
- The demo currently supports bears, horses, and elephants. Other classes are coming soon!
- Make sure you have selected the correct class for your image (though it also works across classes!)
"""

example_images_dir = './gradio_example_images/'
example_images_names = [
    'bear1.png', 'bear2.png', 'bear3.png', 'winnie.png',
    'horse1.png', 'horse2.png', 'mylittlepony.png', 'ponyta.png',
    'elephant1.png', 'elephant2.png', 'dumbo.png', 'phanpy.png'
]
example_images = [os.path.join(example_images_dir, img) for img in example_images_names]
# Template sphere used to mark the best-matching vertex on the 3D shape
sphere_verts_ = torch.load('./models/weights/sphere_verts.pth', map_location=device)
sphere_faces_ = torch.load('./models/weights/sphere_faces.pth', map_location=device)
def center_crop(img):
    """
    Center crops an image to the target size of 224x224.
    """
    width, height = img.size   # Get dimensions
    # Calculate the target size for center cropping
    target_size = min(width, height)
    
    # Calculate the coordinates for center cropping
    left = (width - target_size) // 2
    top = (height - target_size) // 2
    right = left + target_size
    bottom = top + target_size
    
    # Perform center cropping
    cropped_img = img.crop((left, top, right, bottom))
    
    return cropped_img

def draw_point_on_image(image, x_, y_):
    """Draws a red dot on a copy of the image at the specified point."""
    # Make a copy of the image to avoid altering the original
    image_copy = image.copy()
    draw = ImageDraw.Draw(image_copy)
    x, y = x_, y_  # x is the row (vertical) index, y is the column (horizontal) index
    dot_radius = image.size[0] // 100
    # Draw a red dot; PIL expects (column, row) coordinates, hence the (y, x) order
    draw.ellipse([(y - dot_radius, x - dot_radius), (y + dot_radius, x + dot_radius)], fill='red')
    
    return image_copy

def rotate_y(vertices, angle_degrees):
    """Rotate an (N, 3) array of vertices about the y-axis by angle_degrees."""
    angle_radians = np.radians(angle_degrees)
    rotation_matrix = np.array([
        [np.cos(angle_radians), 0, np.sin(angle_radians)],
        [0, 1, 0],
        [-np.sin(angle_radians), 0, np.cos(angle_radians)]
    ])
    rotated_vertices = np.dot(vertices, rotation_matrix)
    return rotated_vertices
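
# Note on conventions: applying the matrix to row vectors (vertices @ R) uses the
# transpose of R, i.e. a rotation by -angle_degrees in the usual column-vector
# convention. For the demo this only affects which side of the shape initially faces
# the camera. Quick sanity check (not used by the app):
#   rotate_y(np.array([[1.0, 0.0, 0.0]]), 90)  ->  approximately [[0.0, 0.0, 1.0]]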

def make_final_mesh(verts, faces, similarities):
    """Build a mesh of the shape colored by pixel-to-vertex similarity,
    with a red sphere marking the most similar vertex."""
    # Vertex with the highest similarity to the clicked pixel
    vert_argmax = similarities.argmax(dim=1)
    vertex = verts[vert_argmax]
    color = [255, 0, 0]

    # Map the per-vertex similarities to the viridis colormap
    vertex_colors = similarities.transpose(1, 0).cpu().detach().numpy()
    vertex_colors = plt.cm.viridis(vertex_colors)[:, 0, :3]

    num_verts_so_far = len(verts)

    # Scale and translate the template sphere to the best-matching vertex
    scale_dot = 0.015  # radius of the sphere
    translation = torch.tensor(vertex, device=device).unsqueeze(0)  # desired location

    verts_sphere = sphere_verts_ * scale_dot + translation
    # Offset the sphere's face indices by the number of shape vertices, since the
    # two vertex arrays are concatenated below
    faces_sphere = sphere_faces_ + num_verts_so_far

    verts_rgb_sphere = torch.tensor([color], device=device).expand(verts_sphere.shape[0], -1)[None] / 255  # [1, N, 3]

    # Concatenate the shape and the marker sphere into a single mesh
    all_verts = np.concatenate([verts, verts_sphere.cpu().numpy()], axis=0)
    all_faces = np.concatenate([faces, faces_sphere.cpu().numpy()], axis=0)
    all_textures = np.concatenate([vertex_colors, verts_rgb_sphere.cpu().numpy()[0]], axis=0)

    return tm.Trimesh(vertices=all_verts, faces=all_faces, vertex_colors=all_textures)

def process_mesh(image, class_name, x_, y_):
    # x_ is the row (vertical) index and y_ the column (horizontal) index of the click
    w, h = image.size

    # Rescale the click coordinates to the 224x224 network input resolution
    x = torch.tensor(x_ * 224 / h)
    y = torch.tensor(y_ * 224 / w)

    hashed_image = image_hash(image)
    if hashed_image in cached_features[class_name]:
        feats = cached_features[class_name][hashed_image]
    else:
        image_tensor = transform(image).unsqueeze(0)

        # Predict dense pixel embeddings with the image encoder
        feats = models[class_name]['image_encoder'](image_tensor.to(device))

        cached_features[class_name][hashed_image] = feats

    # Embedding of the clicked pixel, compared against every vertex embedding of the shape
    sampled_feats = feats[:, :, x.long(), y.long()]
    similarities = torch.einsum('ik, lk -> il', sampled_feats, models[class_name]['shape_feats'])
    # Normalize similarities to [0, 1] for visualization
    similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min())

    faces = models[class_name]['cse'].shape['faces'].cpu().numpy().copy()
    verts = models[class_name]['cse'].shape['verts'].cpu().numpy().copy()

    # Rotate the shape by 145 degrees about the y-axis before display
    verts = rotate_y(verts, 145)

    mesh = make_final_mesh(verts, faces, similarities)
    # Save as .obj so the Model3D component can display it
    mesh_path = './mesh.obj'
    mesh.export(mesh_path)

    return mesh_path
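
# Example of invoking the pipeline without the Gradio UI (a minimal sketch; the click
# coordinates below are arbitrary illustrative values, given in (row, column) order):
#
#   img = Image.open('./gradio_example_images/bear1.png').convert('RGB')
#   obj_path = process_mesh(img, 'bear', x_=100, y_=120)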

def update_output(image, class_name, evt: gr.SelectData):
    if class_name is None:
        class_name = 'bear'
    # Triggered when the input image is clicked; evt holds the clicked pixel
    # coordinates, which are swapped into (row, column) order for the helpers above.
    x_, y_ = evt.index[1], evt.index[0]
    modified_image = draw_point_on_image(image, x_, y_)
    mesh_path = process_mesh(image, class_name, x_, y_)
    return modified_image, mesh_path


with gr.Blocks() as demo:
    gr.Markdown(text_description)

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            class_name = gr.Dropdown(choices=['bear', 'horse', 'elephant'],
                                     label="Select a class (defaults to 'bear')")
            input_img = gr.Image(label="Input", type="pil", width=256)
            gr.Examples(
                examples = example_images,
                inputs = [input_img],
                cache_examples=False,
                label='Feel free to use one of our provided examples!',
                examples_per_page=30
            )
        with gr.Column(scale=1):    
            output_img = gr.Image(label="Selected Point", interactive=False, height=512)
        with gr.Column(scale=1):
            output = gr.Model3D(label='Pixel to Vertex Similarities', height=512)
    


    input_img.select(update_output, [input_img, class_name], [output_img, output])

if __name__ == "__main__":
    demo.launch(share=True)