# Copyright (c) OpenMMLab. All rights reserved.
"""Gradio demo for Pose Anything (category-agnostic pose estimation).

Workflow: upload a support image and click on it to place keypoints, click on
the "Posed Support Image" to connect keypoints into limbs, upload a query
image and press *Evaluate* to predict the same keypoints on the query image.
"""
import argparse
import functools
import os
import random
import subprocess
import sys

# Install the local project in development mode before the project imports
# below. ``sys.executable`` (instead of whatever ``python`` is first on PATH)
# makes the install target the interpreter that is running this demo. A
# failed install is ignored, as it was with the previous ``os.system`` call.
subprocess.run([sys.executable, 'setup.py', 'develop'], check=False)

import gradio as gr
import numpy as np
import torch
from PIL import ImageDraw, Image
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from torchvision import transforms

from demo import Resize_Pad
# Star import kept: it provides TopDownGenerateTargetFewShot (used in
# ``process``) and presumably registers the project's model classes with
# mmpose -- TODO confirm before narrowing it.
from models import *  # noqa: F401,F403

import matplotlib
from matplotlib.figure import Figure

matplotlib.use('agg')

# Checkpoint read by ``process``. ``--checkpoint`` overrides it when this file
# is run as a script; the module-level default keeps ``process`` usable when
# the module is imported instead.
DEFAULT_CHECKPOINT = '1shot-swin_graph_split1.pth'
checkpoint_path = DEFAULT_CHECKPOINT

# RGB colours (0-255) for the first len(COLORS) limbs; later limbs get a
# random colour.
COLORS = [
    [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0],
    [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255],
    [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
    [255, 0, 170], [255, 0, 85], [255, 0, 0]
]


def plot_results(support_img, query_img, support_kp, support_w, query_kp,
                 query_w, skeleton, initial_proposals, prediction, radius=6):
    """Render predicted keypoints and limbs on top of the query image.

    Args:
        support_img: Support image array. Only ``shape[0]`` is read, as the
            factor that scales ``prediction`` to pixel units.
        query_img: Query image array, shown after min-max normalisation.
        support_kp, support_w, query_kp, initial_proposals: Unused; kept for
            signature compatibility.
        query_w: Per-keypoint weights indexable by keypoint id. A keypoint
            (and any limb touching it) is drawn only when its weight is > 0.
            Weight 1 is drawn red; other positive weights are drawn blue.
        skeleton: Sequence of ``(i, j)`` keypoint-index pairs to connect.
        prediction: Tensor whose last entry holds per-keypoint coordinates
            with (x, y) in the first two columns. Presumably normalised to
            [0, 1], since they are multiplied by ``support_img.shape[0]``
            -- TODO confirm.
        radius: Keypoint marker radius in pixels.

    Returns:
        A ``matplotlib.figure.Figure``. It is created without pyplot, so
        pyplot's global figure registry does not keep it alive and repeated
        calls do not accumulate open figures.
    """
    scale = support_img.shape[0]
    keypoints = prediction[-1].cpu().numpy() * scale
    xy = keypoints[:, :2]

    lo, hi = np.min(query_img), np.max(query_img)
    # Guard the constant-image case, which would otherwise compute 0 / 0.
    image = (query_img - lo) / (hi - lo) if hi > lo else query_img - lo

    fig = Figure()
    axes = fig.subplots()
    axes.imshow(image)

    for k in range(keypoints.shape[0]):
        if query_w[k] > 0:
            color = (1, 0, 0, 0.75) if query_w[k] == 1 else (0, 0, 1, 0.6)
            axes.add_patch(plt.Circle(xy[k], radius, color=color))
            axes.text(xy[k, 0], xy[k, 1], k)

    for limb_idx, limb in enumerate(skeleton):
        if limb_idx < len(COLORS):
            color = [x / 255 for x in COLORS[limb_idx]]
        else:
            color = [x / 255 for x in random.sample(range(0, 255), 3)]
        if query_w[limb[0]] > 0 and query_w[limb[1]] > 0:
            axes.add_artist(plt.Line2D(
                [xy[limb[0], 0], xy[limb[1], 0]],
                [xy[limb[0], 1], xy[limb[1], 1]],
                linewidth=6, color=color, alpha=0.6))

    axes.axis('off')
    fig.subplots_adjust(0, 0, 1, 1, 0, 0)
    return fig


@functools.lru_cache(maxsize=2)
def _load_model(cfg_path, ckpt_path):
    """Build the network from ``cfg_path``, load ``ckpt_path`` on CPU, eval().

    The result is cached, so repeated *Evaluate* clicks reuse one model
    instead of rebuilding it and re-reading the checkpoint on every request.
    NOTE(review): this assumes inference does not mutate the model -- confirm
    for the project's model classes.
    """
    cfg = Config.fromfile(cfg_path)
    model = build_posenet(cfg.model)
    if cfg.get('fp16', None) is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, ckpt_path, map_location='cpu')
    model.eval()
    return model


def process(query_img, state, cfg_path='configs/demo_b.py'):
    """Predict the annotated support keypoints on ``query_img``.

    Args:
        query_img: Query image (PIL image from the Gradio component).
        state: Session state dict. Reads ``original_support_image``,
            ``kp_src`` and ``skeleton``.
        cfg_path: Path of the mmcv config describing the model.

    Returns:
        ``(figure, state)``: the rendered prediction and the unmodified state.

    Raises:
        gr.Error: If the support image, its keypoints or the query image are
            missing.
    """
    if state['original_support_image'] is None or len(state['kp_src']) == 0:
        raise gr.Error('Upload a support image and mark at least one '
                       'keypoint before evaluating.')
    if query_img is None:
        raise gr.Error('Upload a query image before evaluating.')

    cfg = Config.fromfile(cfg_path)
    img_size = cfg.model.encoder_config.img_size

    # ``kp_src`` holds (row, col) clicks on the support preview, which
    # ``set_qery`` resizes to a quarter of the original size in each
    # dimension. Rescale them into the img_size x img_size model frame.
    # NOTE(review): this is a plain rescale. If Resize_Pad pads non-square
    # images before resizing, that padding offset is not accounted for here.
    orig_h, orig_w = state['original_support_image'].shape[:2]
    kp_src_np = np.array(state['kp_src'], dtype=np.float32)
    kp_src_np[:, 0] = kp_src_np[:, 0] / (orig_h // 4) * img_size
    kp_src_np[:, 1] = kp_src_np[:, 1] / (orig_w // 4) * img_size
    kp_src_np = np.flip(kp_src_np, 1).copy()  # (row, col) -> (x, y)
    kp_src_tensor = torch.tensor(kp_src_np).float()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(img_size, img_size)])

    # Fall back to a placeholder self-loop when no limbs were drawn. It is
    # kept local: writing it into ``state`` would leave a stray (0, 0) limb
    # in the skeleton for all later evaluations.
    skeleton = state['skeleton'] if len(state['skeleton']) > 0 else [(0, 0)]

    # Both images have their channel order reversed ([:, :, ::-1]) before
    # preprocessing and are flipped back along tensor dim 0 afterwards.
    support_img = preprocess(state['original_support_image']).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]

    # Create heatmap from keypoints
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([img_size, img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(
        data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])

    center = kp_src_tensor.mean(dim=0)
    scale = kp_src_tensor.max(dim=0)[0] - kp_src_tensor.min(dim=0)[0]
    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [skeleton],
                       'query_skeleton': skeleton,
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [center],
                       'query_center': center,
                       'sample_scale': [scale],
                       'query_scale': scale,
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }

    model = _load_model(cfg_path, checkpoint_path)
    with torch.no_grad():
        outputs = model(**data)

    # visualize results
    vis_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    out = plot_results(vis_s_image, vis_q_image, kp_src_3d, vis_weight, None,
                       vis_weight, skeleton, None,
                       torch.tensor(outputs['points']).squeeze(0))
    return out, state


with gr.Blocks() as demo:
    # Per-session annotation state:
    #   kp_src: clicked keypoints as (row, col) on the support preview.
    #   skeleton: limbs as (keypoint_idx, keypoint_idx) pairs.
    #   count: 0 when the next limb click starts a limb, 1 when it ends one.
    #   color_idx: index into COLORS for the next limb.
    #   prev_pt / prev_pt_idx: (x, y) and index of the pending limb start.
    #   prev_clicked: last processed limb click; identical repeats are dropped.
    #   original_support_image: uploaded support image, channel order
    #       reversed.
    state = gr.State({
        'kp_src': [],
        'skeleton': [],
        'count': 0,
        'color_idx': 0,
        'prev_pt': None,
        'prev_pt_idx': None,
        'prev_clicked': None,
        'original_support_image': None,
    })

    gr.Markdown('''
    # Pose Anything Demo
    We present a novel approach to category agnostic pose estimation that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.
    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
    ## Instructions
    1. Upload an image of the object you want to pose on the **left** image.
    2. Click on the **left** image to mark keypoints.
    3. Click on the keypoints on the **right** image to mark limbs.
    4. Upload an image of the object you want to pose to the query image ( **bottom**).
    5. Click **Evaluate** to pose the query image.
    ''')
    with gr.Row():
        support_img = gr.Image(label="Support Image",
                               type="pil",
                               info='Click to mark keypoints').style(
            height=400, width=400)
        posed_support = gr.Image(label="Posed Support Image",
                                 type="pil",
                                 interactive=False).style(height=400,
                                                          width=400)
    with gr.Row():
        query_img = gr.Image(label="Query Image",
                             type="pil").style(height=400, width=400)
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
        output_img = gr.Plot(label="Output Image", height=400, width=400)

    def get_select_coords(kp_support, limb_support, state,
                          evt: gr.SelectData, r=0.015):
        """Record a click on the support image as a keypoint and draw it.

        The red dot is drawn on the support image and on the posed-support
        image (the support image itself when the latter is empty).

        Args:
            kp_support: Clicked support image (PIL), or None.
            limb_support: Posed-support image (PIL), or None.
            state: Session state. The click is appended to ``kp_src`` as
                (row, col) unless it is already present.
            evt: Gradio select event; ``evt.index`` is read as (x, y).
            r: Dot radius as a fraction of the image width.

        Returns:
            ``(keypoint_canvas, limb_canvas, state)``.
        """
        if kp_support is None:
            return kp_support, limb_support, state
        pixel = (evt.index[1], evt.index[0])
        if pixel[0] is None or pixel[1] is None:
            print("Invalid pixel")
            return kp_support, limb_support, state
        if pixel in state['kp_src']:
            print("Invalid pixel")
        else:
            state['kp_src'].append(pixel)

        canvas_kp = kp_support
        canvas_limb = kp_support if limb_support is None else limb_support
        img_w, _ = canvas_kp.size
        radius = int(r * img_w)
        bbox = [(pixel[1] - radius, pixel[0] - radius),
                (pixel[1] + radius, pixel[0] + radius)]
        ImageDraw.Draw(canvas_kp).ellipse(bbox, fill=(255, 0, 0, 255))
        ImageDraw.Draw(canvas_limb).ellipse(bbox, fill=(255, 0, 0, 255))
        return canvas_kp, canvas_limb, state

    def get_limbs(kp_support, state, evt: gr.SelectData, r=0.02, width=0.02):
        """Connect keypoints into limbs by clicking the posed-support image.

        Each click snaps to the nearest recorded keypoint. The first click of
        a pair selects that keypoint. The second click either adds a limb to
        ``state['skeleton']`` and draws it, or cancels the selection when it
        hits the same keypoint. A click identical to the previous one is
        ignored.

        Args:
            kp_support: Posed-support image (PIL) to draw on, or None.
            state: Session state. Reads ``kp_src``; updates ``skeleton``,
                ``count``, ``color_idx``, ``prev_pt``, ``prev_pt_idx`` and
                ``prev_clicked``.
            evt: Gradio select event; ``evt.index`` is read as (x, y).
            r: Keypoint dot radius as a fraction of the image width.
            width: Limb line width as a fraction of the image width.

        Returns:
            ``(canvas, state)``.
        """
        pixel = (evt.index[1], evt.index[0])
        # Nothing to do: no image, no keypoints to snap to, or a repeat click.
        if (kp_support is None or not state['kp_src']
                or pixel == state['prev_clicked']):
            return kp_support, state
        canvas_kp = kp_support
        img_w, _ = canvas_kp.size
        radius = int(r * img_w)
        line_width = int(width * img_w)
        state['prev_clicked'] = pixel

        closest_point = min(
            state['kp_src'],
            key=lambda p: (p[0] - pixel[0]) ** 2 + (p[1] - pixel[1]) ** 2)
        closest_point_index = state['kp_src'].index(closest_point)
        draw_limb = ImageDraw.Draw(canvas_kp)
        if state['color_idx'] < len(COLORS):
            c = COLORS[state['color_idx']]
        else:
            c = random.choices(range(256), k=3)
        bbox = [(closest_point[1] - radius, closest_point[0] - radius),
                (closest_point[1] + radius, closest_point[0] + radius)]
        draw_limb.ellipse(bbox, fill=tuple(c))
        if state['count'] == 0:
            # First click of a pair: remember the keypoint as (x, y).
            state['prev_pt'] = closest_point[1], closest_point[0]
            state['prev_pt_idx'] = closest_point_index
            state['count'] = state['count'] + 1
        else:
            if state['prev_pt_idx'] != closest_point_index:
                # Create Line and add Limb
                draw_limb.line(
                    [state['prev_pt'], (closest_point[1], closest_point[0])],
                    fill=tuple(c),
                    width=line_width)
                state['skeleton'].append(
                    (state['prev_pt_idx'], closest_point_index))
                state['color_idx'] = state['color_idx'] + 1
            else:
                # Same keypoint clicked twice: cancel and redraw it red.
                draw_limb.ellipse(bbox, fill=(255, 0, 0, 255))
            state['count'] = 0
        return canvas_kp, state

    def set_qery(support_img, state):
        """Reset the annotation state for a newly uploaded support image.

        Stores the image (channel order reversed) for ``process`` and returns
        a preview downscaled by 4 in each dimension, preserving the aspect
        ratio that ``process`` assumes when rescaling clicks. The preview is
        returned for both the support and posed-support components.
        """
        state['skeleton'].clear()
        state['kp_src'].clear()
        # Also drop any half-drawn limb and colour progress from the previous
        # image; otherwise the next limb click could join a stale keypoint.
        state['count'] = 0
        state['color_idx'] = 0
        state['prev_pt'] = None
        state['prev_pt_idx'] = None
        state['prev_clicked'] = None
        state['original_support_image'] = np.array(support_img)[:, :,
                                                               ::-1].copy()
        width, height = support_img.size
        support_img = support_img.resize((width // 4, height // 4),
                                         Image.Resampling.LANCZOS)
        return support_img, support_img, state

    support_img.select(get_select_coords,
                       [support_img, posed_support, state],
                       [support_img, posed_support, state])
    support_img.upload(set_qery,
                       inputs=[support_img, state],
                       outputs=[support_img, posed_support, state])
    posed_support.select(get_limbs,
                         [posed_support, state],
                         [posed_support, state])
    eval_btn.click(fn=process,
                   inputs=[query_img, state],
                   outputs=[output_img, state])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Pose Anything Demo')
    parser.add_argument('--checkpoint',
                        help='checkpoint path',
                        default=DEFAULT_CHECKPOINT)
    args = parser.parse_args()
    checkpoint_path = args.checkpoint
    print("Loading checkpoint from {}".format(checkpoint_path))
    print(os.path.exists(checkpoint_path))
    demo.launch()