import gradio as gr
import os
import numpy as np
import trimesh as tm
from src.model import DinoV2
from src.shape_model import CSE
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import hashlib
def image_hash(image):
    """Generate a SHA-256 hash for an image."""
    image_bytes = image.tobytes()
    hash_function = hashlib.sha256()
    hash_function.update(image_bytes)
    return hash_function.hexdigest()
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
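# Load per-class models: a DinoV2-based image encoder and the CSE shape embedding model,
# together with its precomputed functional (LBO) basis. The per-vertex shape embeddings
# ('shape_feats') are computed once up front and reused for every query.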
models = {}
for class_name in ['bear', 'horse', 'elephant']:
    print(f'Loading model weights for {class_name}')
    models[class_name] = {
        'image_encoder': DinoV2(16),
        'cse': CSE(class_name=class_name, num_basis=64, device=device)
    }
    models[class_name]['image_encoder'].load_state_dict(torch.load(f'./models/weights/{class_name}.pth', map_location=device))
    models[class_name]['cse'].load_state_dict(torch.load(f'./models/weights/{class_name}_cse.pth', map_location=device))
    models[class_name]['cse'].functional_basis = torch.load(f'./models/weights/{class_name}_lbo.pth', map_location=device)
    models[class_name]['image_encoder'] = models[class_name]['image_encoder'].to(device)
    models[class_name]['cse'] = models[class_name]['cse'].to(device)
    models[class_name]['cse'].functional_basis = models[class_name]['cse'].functional_basis.to(device)
    models[class_name]['cse'].weight_matrix = models[class_name]['cse'].weight_matrix.to(device)
    models[class_name]['shape_feats'] = models[class_name]['cse']().to(device)
# Preprocessing for the image encoder: resize to 224x224, convert the PIL image to a
# tensor, and normalize with the standard ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
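# Per-class cache of image features, keyed by a SHA-256 hash of the raw image bytes,
# so repeated clicks on the same image do not re-run the image encoder.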
cached_features = {'bear': {}, 'horse': {}, 'elephant': {}}
text_description = """
# Demo for SHIC: Shape-Image Correspondences with no Keypoint Supervision (ECCV 2024)
Project website: https://www.robots.ox.ac.uk/~vgg/research/shic/
- **Step 1:** First select a class (it defaults to 'bear')
- **Step 2:** Upload an image of an animal from that class (or select one of the provided examples)
- **Step 3:** Click on the image (somewhere over the object) to see the image-to-shape correspondences. You can keep clicking on the input image to see new correspondences.
Notes:
- You can click and drag to rotate the 3D shape
- Currently the demo supports bears, horses, and elephants. Other classes coming soon!
- Make sure you have selected the correct class for your image (it works cross-class too, though!)
"""
example_images_dir = './gradio_example_images/'
example_images_names = [
    'bear1.png', 'bear2.png', 'bear3.png', 'winnie.png',
    'horse1.png', 'horse2.png', 'mylittlepony.png', 'ponyta.png',
    'elephant1.png', 'elephant2.png', 'dumbo.png', 'phanpy.png'
]
example_images = [os.path.join(example_images_dir, img) for img in example_images_names]
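# Unit sphere mesh used in make_final_mesh to mark the best-matching vertex on the 3D shape.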
sphere_verts_ = torch.load('./models/weights/sphere_verts.pth', map_location=device)
sphere_faces_ = torch.load('./models/weights/sphere_faces.pth', map_location=device)
def center_crop(img):
    """
    Center-crops an image to a square whose side is min(width, height).
    """
    width, height = img.size  # Get dimensions
    # Calculate the target size for center cropping
    target_size = min(width, height)
    # Calculate the coordinates for center cropping
    left = (width - target_size) // 2
    top = (height - target_size) // 2
    right = left + target_size
    bottom = top + target_size
    # Perform center cropping
    cropped_img = img.crop((left, top, right, bottom))
    return cropped_img
def draw_point_on_image(image, x_, y_):
    """Draws a red dot on a copy of the image at the given point (x_ = row, y_ = column)."""
    # Make a copy of the image to avoid altering the original
    image_copy = image.copy()
    draw = ImageDraw.Draw(image_copy)
    x, y = x_, y_  # x is the row index, y is the column index
    dot_radius = image.size[0] // 100
    # Draw a red dot; PIL expects (column, row) coordinates, hence y before x
    draw.ellipse([(y - dot_radius, x - dot_radius), (y + dot_radius, x + dot_radius)], fill='red')
    return image_copy
def rotate_y(vertices, angle_degrees):
    angle_radians = np.radians(angle_degrees)
    rotation_matrix = np.array([
        [np.cos(angle_radians), 0, np.sin(angle_radians)],
        [0, 1, 0],
        [-np.sin(angle_radians), 0, np.cos(angle_radians)]
    ])
    # Assuming vertices is a numpy array of shape (N, 3)
    rotated_vertices = np.dot(vertices, rotation_matrix)
    return rotated_vertices
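# Build the output mesh: the animal shape colored by per-vertex similarity to the
# clicked pixel, plus a small red sphere marking the most similar vertex.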
def make_final_mesh(verts, faces, similarities):
    vert_argmax = similarities.argmax(dim=1)
    vertex = verts[vert_argmax]
    color = [255, 0, 0]
    vertex_colors = similarities.transpose(1, 0).cpu().detach().numpy()
    # Map similarities to the viridis colormap
    vertex_colors = plt.cm.viridis(vertex_colors)[:, 0, :3]
    num_verts_so_far = len(verts)
    # Create a sphere mesh, scaled and translated to the best-matching vertex
    scale_dot = 0.015  # radius of the sphere
    translation = torch.tensor(vertex, device=device).unsqueeze(0)  # desired location
    verts_sphere = sphere_verts_ * scale_dot + translation  # scale and translate vertices
    faces_sphere = sphere_faces_ + num_verts_so_far  # offset face indices past the shape's vertices
    verts_rgb_sphere = torch.tensor([color], device=device).expand(verts_sphere.shape[0], -1)[None] / 255  # [1, N, 3]
    # Concatenate the shape and the sphere marker into a single mesh
    all_verts = np.concatenate([verts, verts_sphere.cpu().numpy()], axis=0)
    all_faces = np.concatenate([faces, faces_sphere.cpu().numpy()], axis=0)
    all_textures = np.concatenate([vertex_colors, verts_rgb_sphere.cpu().numpy()[0]], axis=0)
    return tm.Trimesh(vertices=all_verts, faces=all_faces, vertex_colors=all_textures)
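# Full query pipeline: encode the image (or reuse cached features), sample the feature at
# the clicked pixel, compute its similarity to every vertex embedding, and export a colored
# mesh for display.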
def process_mesh(image, class_name, x_, y_):
    width, height = image.size
    x = torch.tensor(x_ * 224 / height)  # clicked row, rescaled to the 224x224 model input
    y = torch.tensor(y_ * 224 / width)   # clicked column, rescaled to the 224x224 model input
    hashed_image = image_hash(image)
    if hashed_image in cached_features[class_name]:
        feats = cached_features[class_name][hashed_image]
    else:
        image_tensor = transform(image).unsqueeze(0)
        # Predict per-pixel embeddings
        feats = models[class_name]['image_encoder'](image_tensor.to(device))
        cached_features[class_name][hashed_image] = feats
    sampled_feats = feats[:, :, x.long(), y.long()]
    similarities = torch.einsum('ik, lk -> il', sampled_feats, models[class_name]['shape_feats'])
    # Normalize similarities to [0, 1]
    similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min())
    faces = models[class_name]['cse'].shape['faces'].cpu().numpy().copy()
    verts = models[class_name]['cse'].shape['verts'].cpu().numpy().copy()
    # Rotate the shape to a nicer default viewing angle
    verts = rotate_y(verts, 145)
    mesh = make_final_mesh(verts, faces, similarities)
    # Save as .obj so the Model3D component can display it
    mesh_path = './mesh.obj'
    mesh.export(mesh_path)
    return mesh_path
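# Gradio select-event callback: marks the clicked point on the image and returns the
# corresponding similarity-colored mesh.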
def update_output(image, class_name, evt: gr.SelectData):
    if class_name is None:
        class_name = 'bear'
    # This function is triggered when the input image is clicked.
    # evt contains the click event data, including the pixel coordinates.
    x_, y_ = evt.index[1], evt.index[0]  # reorder to (row, column)
    modified_image = draw_point_on_image(image, x_, y_)
    mesh_path = process_mesh(image, class_name, x_, y_)
    return modified_image, mesh_path
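# Interface layout: class selector and input image on the left, the clicked point in the
# middle, and the 3D similarity visualization on the right.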
with gr.Blocks() as demo:
    gr.Markdown(text_description)
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            class_name = gr.Dropdown(choices=['bear', 'horse', 'elephant'],
                                     label="Select a class (defaults to 'bear')")
            input_img = gr.Image(label="Input", type="pil", width=256)
            gr.Examples(
                examples=example_images,
                inputs=[input_img],
                cache_examples=False,
                label='Feel free to use one of our provided examples!',
                examples_per_page=30
            )
        with gr.Column(scale=1):
            output_img = gr.Image(label="Selected Point", interactive=False, height=512)
        with gr.Column(scale=1):
            output = gr.Model3D(label='Pixel to Vertex Similarities', height=512)
    input_img.select(update_output, [input_img, class_name], [output_img, output])
if __name__ == "__main__":
    demo.launch(share=True)