import os
import torch
from dataclasses import dataclass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import vqvae
import vit
from typing import Literal
from diffusion import create_diffusion
from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
from segment_hoi import init_sam
from io import BytesIO
from PIL import Image
import random
from copy import deepcopy
from typing import Optional
import requests
from huggingface_hub import hf_hub_download
import spaces
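# Demo-wide configuration: MAX_N / FIX_MAX_N cap the number of images shown in the result
# galleries (placeholder fills unused slots), NEW_MODEL / MODEL_EPOCH select which checkpoint
# to download, and REF_POSE_MASK toggles keypoint detection + SAM masking of the reference hand.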
MAX_N = 6
FIX_MAX_N = 6
placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
NEW_MODEL = True
MODEL_EPOCH = 6
REF_POSE_MASK = True
def set_seed(seed):
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


device = "cuda" if torch.cuda.is_available() else "cpu"


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def unnormalize(x):
    # Map images from [-1, 1] back to uint8 [0, 255].
    return (((x + 1) / 2) * 255).astype(np.uint8)
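# Keypoint convention used throughout: 42 keypoints per image, the right hand at indices 0-20
# and the left hand at indices 21-41, each hand following the 21-landmark OpenPose/MediaPipe
# layout (wrist at index 0, then four joints per finger).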
def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21): | |
# Define the connections between joints for drawing lines and their corresponding colors | |
connections = [ | |
((0, 1), "red"), | |
((1, 2), "green"), | |
((2, 3), "blue"), | |
((3, 4), "purple"), | |
((0, 5), "orange"), | |
((5, 6), "pink"), | |
((6, 7), "brown"), | |
((7, 8), "cyan"), | |
((0, 9), "yellow"), | |
((9, 10), "magenta"), | |
((10, 11), "lime"), | |
((11, 12), "indigo"), | |
((0, 13), "olive"), | |
((13, 14), "teal"), | |
((14, 15), "navy"), | |
((15, 16), "gray"), | |
((0, 17), "lavender"), | |
((17, 18), "silver"), | |
((18, 19), "maroon"), | |
((19, 20), "fuchsia"), | |
] | |
H, W, C = img.shape | |
# Create a figure and axis | |
plt.figure() | |
ax = plt.gca() | |
# Plot joints as points | |
ax.imshow(img) | |
start_is = [] | |
if "right" in side: | |
start_is.append(0) | |
if "left" in side: | |
start_is.append(21) | |
for start_i in start_is: | |
joints = all_joints[start_i : start_i + n_avail_joints] | |
if len(joints) == 1: | |
ax.scatter(joints[0][0], joints[0][1], color="red", s=10) | |
else: | |
for connection, color in connections[: len(joints) - 1]: | |
joint1 = joints[connection[0]] | |
joint2 = joints[connection[1]] | |
ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color) | |
ax.set_xlim([0, W]) | |
ax.set_ylim([0, H]) | |
ax.grid(False) | |
ax.set_axis_off() | |
ax.invert_yaxis() | |
# plt.subplots_adjust(wspace=0.01) | |
# plt.show() | |
buf = BytesIO() | |
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) | |
plt.close() | |
# Convert BytesIO object to numpy array | |
buf.seek(0) | |
    img_pil = Image.open(buf)
    # PIL's resize expects (width, height); the output is square here, but keep the order correct.
    img_pil = img_pil.resize((W, H))
numpy_img = np.array(img_pil) | |
return numpy_img | |
def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True): | |
"""Overlay mask on image for visualization purpose. | |
Args: | |
image (H, W, 3) or (H, W): input image | |
mask (H, W): mask to be overlaid | |
color: the color of overlaid mask | |
alpha: the transparency of the mask | |
""" | |
out = deepcopy(image) | |
img = deepcopy(image) | |
img[mask == 1] = color | |
if transparent: | |
out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out) | |
else: | |
out = img | |
return out | |
def scale_keypoint(keypoint, original_size, target_size):
    """Scale a keypoint based on the resizing of the image.

    Note: this definition shadows the `scale_keypoint` imported from `utils` above.
    """
keypoint_copy = keypoint.copy() | |
keypoint_copy[:, 0] *= target_size[0] / original_size[0] | |
keypoint_copy[:, 1] *= target_size[1] / original_size[1] | |
return keypoint_copy | |
print("Configure...") | |
class HandDiffOpts: | |
run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5" | |
sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt" | |
log_dir: str = "/users/kchen157/scratch/log" | |
data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff" | |
image_size: tuple = (256, 256) | |
latent_size: tuple = (32, 32) | |
latent_dim: int = 4 | |
mask_bg: bool = False | |
kpts_form: str = "heatmap" | |
n_keypoints: int = 42 | |
n_mask: int = 1 | |
noise_steps: int = 1000 | |
test_sampling_steps: int = 250 | |
ddim_steps: int = 100 | |
ddim_discretize: str = "uniform" | |
ddim_eta: float = 0.0 | |
beta_start: float = 8.5e-4 | |
beta_end: float = 0.012 | |
latent_scaling_factor: float = 0.18215 | |
cfg_pose: float = 5.0 | |
cfg_appearance: float = 3.5 | |
batch_size: int = 25 | |
lr: float = 1e-5 | |
max_epochs: int = 500 | |
log_every_n_steps: int = 100 | |
limit_val_batches: int = 1 | |
n_gpu: int = 8 | |
num_nodes: int = 1 | |
precision: str = "16-mixed" | |
profiler: str = "simple" | |
swa_epoch_start: int = 10 | |
swa_lrs: float = 1e-3 | |
num_workers: int = 10 | |
n_val_samples: int = 4 | |
# load models | |
token = os.getenv("HF_TOKEN") | |
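# Model stack, as wired below: a DiT diffusion backbone over 32x32x4 latents whose input also
# carries 42 keypoint-heatmap channels and 1 hand-mask channel, an autoencoder (loaded with
# Stable Diffusion's vae-ft-mse weights) for encoding/decoding images, SAM for hand
# segmentation, and MediaPipe Hands for keypoint detection. Checkpoints are fetched from the
# Chaerin5/FoundHand-weights repo on the Hugging Face Hub.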
if NEW_MODEL: | |
opts = HandDiffOpts() | |
if MODEL_EPOCH == 7: | |
model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt' | |
elif MODEL_EPOCH == 6: | |
# model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt" | |
model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt", token=token) | |
elif MODEL_EPOCH == 4: | |
model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt" | |
elif MODEL_EPOCH == 10: | |
model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt" | |
else: | |
raise ValueError(f"new model epoch should be either 6 or 7, got {MODEL_EPOCH}") | |
# vae_path = './vae-ft-mse-840000-ema-pruned.ckpt' | |
vae_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="vae-ft-mse-840000-ema-pruned.ckpt", token=token) | |
# sd_path = './sd-v1-4.ckpt' | |
print('Load diffusion model...') | |
diffusion = create_diffusion(str(opts.test_sampling_steps)) | |
model = vit.DiT_XL_2( | |
input_size=opts.latent_size[0], | |
latent_dim=opts.latent_dim, | |
in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask, | |
learn_sigma=True, | |
).to(device) | |
# ckpt_state_dict = torch.load(model_path)['model_state_dict'] | |
ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict'] | |
missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False) | |
model = model.to(device) | |
model.eval() | |
print(missing_keys, extra_keys) | |
assert len(missing_keys) == 0 | |
vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict'] | |
print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}") | |
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False) | |
print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}") | |
print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}") | |
print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}") | |
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False) | |
print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}") | |
print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}") | |
autoencoder = autoencoder.to(device) | |
autoencoder.eval() | |
print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}") | |
print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}") | |
print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}") | |
assert len(missing_keys) == 0 | |
# else: | |
# opts = HandDiffOpts() | |
# model_path = './finetune_epoch=5-step=130000.ckpt' | |
# sd_path = './sd-v1-4.ckpt' | |
# print('Load diffusion model...') | |
# diffusion = create_diffusion(str(opts.test_sampling_steps)) | |
# model = vit.DiT_XL_2( | |
# input_size=opts.latent_size[0], | |
# latent_dim=opts.latent_dim, | |
# in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask, | |
# learn_sigma=True, | |
# ).to(device) | |
# ckpt_state_dict = torch.load(model_path)['state_dict'] | |
# dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')} | |
# vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')} | |
# missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False) | |
# model.eval() | |
# assert len(missing_keys) == 0 and len(extra_keys) == 0 | |
# autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device) | |
# missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False) | |
# autoencoder.eval() | |
# assert len(missing_keys) == 0 and len(extra_keys) == 0 | |
sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token) | |
sam_predictor = init_sam(ckpt_path=sam_path, device='cpu') | |
print("Mediapipe hand detector and SAM ready...") | |
mp_hands = mp.solutions.hands | |
hands = mp_hands.Hands( | |
static_image_mode=True, # Use False if image is part of a video stream | |
max_num_hands=2, # Maximum number of hands to detect | |
min_detection_confidence=0.1, | |
) | |
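# get_ref_anno builds the reference conditioning: MediaPipe detects keypoints on the cropped
# reference (its handedness labels are mirrored, so "Left" fills the right-hand slot 0-20 and
# "Right" the left-hand slot 21-41), SAM segments the hand from wrist-point prompts, and the
# condition is packed as [VAE latent (4ch) | keypoint heatmaps (42ch) | hand mask (1ch)] at the
# 32x32 latent resolution.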
def get_ref_anno(ref): | |
    if ref is None:
        # The click handler expects three outputs: (ref_img, ref_pose, ref_cond).
        return None, None, None
    # Reload the VAE weights on each call; likely a debugging safeguard, harmless but redundant.
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
img = ref["composite"][..., :3] | |
img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA) | |
keypts = np.zeros((42, 2)) | |
if REF_POSE_MASK: | |
mp_pose = hands.process(img) | |
detected = np.array([0, 0]) | |
start_idx = 0 | |
if mp_pose.multi_hand_landmarks: | |
# handedness is flipped assuming the input image is mirrored in MediaPipe | |
for hand_landmarks, handedness in zip( | |
mp_pose.multi_hand_landmarks, mp_pose.multi_handedness | |
): | |
# actually right hand | |
if handedness.classification[0].label == "Left": | |
start_idx = 0 | |
detected[0] = 1 | |
# actually left hand | |
elif handedness.classification[0].label == "Right": | |
start_idx = 21 | |
detected[1] = 1 | |
for i, landmark in enumerate(hand_landmarks.landmark): | |
keypts[start_idx + i] = [ | |
landmark.x * opts.image_size[1], | |
landmark.y * opts.image_size[0], | |
] | |
sam_predictor.set_image(img) | |
l = keypts[:21].shape[0] | |
if keypts[0].sum() != 0 and keypts[21].sum() != 0: | |
input_point = np.array([keypts[0], keypts[21]]) | |
input_label = np.array([1, 1]) | |
elif keypts[0].sum() != 0: | |
input_point = np.array(keypts[:1]) | |
input_label = np.array([1]) | |
elif keypts[21].sum() != 0: | |
input_point = np.array(keypts[21:22]) | |
input_label = np.array([1]) | |
masks, _, _ = sam_predictor.predict( | |
point_coords=input_point, | |
point_labels=input_label, | |
multimask_output=False, | |
) | |
hand_mask = masks[0] | |
masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None]) | |
ref_pose = visualize_hand(keypts, masked_img) | |
else: | |
raise gr.Error("No hands detected in the reference image.") | |
else: | |
hand_mask = np.zeros_like(img[:,:, 0]) | |
ref_pose = np.zeros_like(img) | |
print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}") | |
def make_ref_cond( | |
img, | |
keypts, | |
hand_mask, | |
device="cuda", | |
target_size=(256, 256), | |
latent_size=(32, 32), | |
): | |
image_transform = Compose( | |
[ | |
ToTensor(), | |
Resize(target_size), | |
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), | |
] | |
) | |
image = image_transform(img) | |
kpts_valid = check_keypoints_validity(keypts, target_size) | |
heatmaps = torch.tensor( | |
keypoint_heatmap( | |
scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0 | |
) | |
* kpts_valid[:, None, None], | |
dtype=torch.float, | |
)[None, ...] | |
mask = torch.tensor( | |
cv2.resize( | |
hand_mask.astype(int), | |
dsize=latent_size, | |
interpolation=cv2.INTER_NEAREST, | |
), | |
dtype=torch.float, | |
).unsqueeze(0)[None, ...] | |
return image[None, ...], heatmaps, mask | |
print(f"img.max(): {img.max()}, img.min(): {img.min()}") | |
image, heatmaps, mask = make_ref_cond( | |
img, | |
keypts, | |
hand_mask, | |
device="cuda", | |
target_size=opts.image_size, | |
latent_size=opts.latent_size, | |
) | |
print(f"image.max(): {image.max()}, image.min(): {image.min()}") | |
print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}") | |
print(f"autoencoder encoder before operating max: {min([p.min() for p in autoencoder.encoder.parameters()])}") | |
print(f"autoencoder encoder before operating min: {max([p.max() for p in autoencoder.encoder.parameters()])}") | |
print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}") | |
latent = opts.latent_scaling_factor * autoencoder.encode(image).sample() | |
print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}") | |
if not REF_POSE_MASK: | |
heatmaps = torch.zeros_like(heatmaps) | |
mask = torch.zeros_like(mask) | |
print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}") | |
print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}") | |
ref_cond = torch.cat([latent, heatmaps, mask], 1) | |
print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}") | |
return img, ref_pose, ref_cond | |
def get_target_anno(target): | |
    if target is None:
        # The click handler expects four outputs: (pose_img, target_pose, target_cond, target_keypts).
        return None, None, None, None
pose_img = target["composite"][..., :3] | |
pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA) | |
# detect keypoints | |
mp_pose = hands.process(pose_img) | |
target_keypts = np.zeros((42, 2)) | |
detected = np.array([0, 0]) | |
start_idx = 0 | |
if mp_pose.multi_hand_landmarks: | |
# handedness is flipped assuming the input image is mirrored in MediaPipe | |
for hand_landmarks, handedness in zip( | |
mp_pose.multi_hand_landmarks, mp_pose.multi_handedness | |
): | |
# actually right hand | |
if handedness.classification[0].label == "Left": | |
start_idx = 0 | |
detected[0] = 1 | |
# actually left hand | |
elif handedness.classification[0].label == "Right": | |
start_idx = 21 | |
detected[1] = 1 | |
for i, landmark in enumerate(hand_landmarks.landmark): | |
target_keypts[start_idx + i] = [ | |
landmark.x * opts.image_size[1], | |
landmark.y * opts.image_size[0], | |
] | |
target_pose = visualize_hand(target_keypts, pose_img) | |
kpts_valid = check_keypoints_validity(target_keypts, opts.image_size) | |
target_heatmaps = torch.tensor( | |
keypoint_heatmap( | |
scale_keypoint(target_keypts, opts.image_size, opts.latent_size), | |
opts.latent_size, | |
var=1.0, | |
) | |
* kpts_valid[:, None, None], | |
dtype=torch.float, | |
# device=device, | |
)[None, ...] | |
target_cond = torch.cat( | |
[target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1 | |
) | |
else: | |
raise gr.Error("No hands detected in the target image.") | |
return pose_img, target_pose, target_cond, target_keypts | |
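# For "Fix Hands", the inpaint mask comes from the user's brush strokes: the alpha channel of
# the first editor layer, resized to the model's image size and binarized at a threshold of 128.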
def get_mask_inpaint(ref): | |
inpaint_mask = np.array(ref["layers"][0])[..., -1] | |
inpaint_mask = cv2.resize( | |
inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA | |
) | |
inpaint_mask = (inpaint_mask >= 128).astype(np.uint8) | |
return inpaint_mask | |
def visualize_ref(crop, brush): | |
if crop is None or brush is None: | |
return None | |
inpainted = brush["layers"][0][..., -1] | |
img = crop["background"][..., :3] | |
img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA) | |
mask = inpainted < 128 | |
# img = img.astype(np.int32) | |
# img[mask, :] = img[mask, :] - 50 | |
# img[np.any(img<0, axis=-1)]=0 | |
# img = img.astype(np.uint8) | |
img = mask_image(img, mask) | |
return img | |
def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData): | |
if keypoints is None: | |
keypoints = [[], []] | |
kps = np.zeros((42, 2)) | |
if side == "right": | |
if len(keypoints[0]) == 21: | |
gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.") | |
else: | |
keypoints[0].append(list(evt.index)) | |
len_kps = len(keypoints[0]) | |
kps[:len_kps] = np.array(keypoints[0]) | |
elif side == "left": | |
if len(keypoints[1]) == 21: | |
gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.") | |
else: | |
keypoints[1].append(list(evt.index)) | |
len_kps = len(keypoints[1]) | |
kps[21 : 21 + len_kps] = np.array(keypoints[1]) | |
vis_hand = visualize_hand(kps, img, side, len_kps) | |
return vis_hand, keypoints | |
def undo_kps(img, keypoints, side: Literal["right", "left"]): | |
if keypoints is None: | |
return img, None | |
kps = np.zeros((42, 2)) | |
if side == "right": | |
if len(keypoints[0]) == 0: | |
return img, keypoints | |
keypoints[0].pop() | |
len_kps = len(keypoints[0]) | |
kps[:len_kps] = np.array(keypoints[0]) | |
elif side == "left": | |
if len(keypoints[1]) == 0: | |
return img, keypoints | |
keypoints[1].pop() | |
len_kps = len(keypoints[1]) | |
kps[21 : 21 + len_kps] = np.array(keypoints[1]) | |
vis_hand = visualize_hand(kps, img, side, len_kps) | |
return vis_hand, keypoints | |
def reset_kps(img, keypoints, side: Literal["right", "left"]): | |
if keypoints is None: | |
return img, None | |
if side == "right": | |
keypoints[0] = [] | |
elif side == "left": | |
keypoints[1] = [] | |
return img, keypoints | |
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg): | |
set_seed(seed) | |
z = torch.randn( | |
(num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), | |
device=device, | |
) | |
print(f"z.device: {z.device}") | |
target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device) | |
ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device) | |
print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}") | |
print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}") | |
# novel view synthesis mode = off | |
nvs = torch.zeros(num_gen, dtype=torch.int, device=device) | |
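    # Classifier-free guidance: the batch is doubled, with the first half seeing the real
    # conditions and the second half zeroed conditions with nvs set to 2; model.forward_with_cfg
    # then blends the conditional and unconditional predictions using cfg_scale.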
z = torch.cat([z, z], 0) | |
model_kwargs = dict( | |
target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]), | |
ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]), | |
nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]), | |
cfg_scale=cfg, | |
) | |
samples, _ = diffusion.p_sample_loop( | |
model.forward_with_cfg, | |
z.shape, | |
z, | |
clip_denoised=False, | |
model_kwargs=model_kwargs, | |
progress=True, | |
device=device, | |
).chunk(2) | |
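    # Only the first half of the doubled batch is kept (.chunk(2) above); the latents are then
    # rescaled by 1/latent_scaling_factor (0.18215, the Stable Diffusion convention) before
    # decoding with the VAE back to images in [-1, 1].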
sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor) | |
sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0) | |
sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy()) | |
results = [] | |
results_pose = [] | |
for i in range(MAX_N): | |
if i < num_gen: | |
results.append(sampled_images[i]) | |
results_pose.append(visualize_hand(target_keypts, sampled_images[i])) | |
else: | |
results.append(placeholder) | |
results_pose.append(placeholder) | |
print(f"results[0].max(): {results[0].max()}") | |
return results, results_pose | |
# @spaces.GPU(duration=120) | |
def ready_sample(img_ori, inpaint_mask, keypts): | |
img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA) | |
sam_predictor.set_image(img) | |
if len(keypts[0]) == 0: | |
keypts[0] = np.zeros((21, 2)) | |
elif len(keypts[0]) == 21: | |
keypts[0] = np.array(keypts[0], dtype=np.float32) | |
else: | |
gr.Info("Number of right hand keypoints should be either 0 or 21.") | |
return None, None | |
if len(keypts[1]) == 0: | |
keypts[1] = np.zeros((21, 2)) | |
elif len(keypts[1]) == 21: | |
keypts[1] = np.array(keypts[1], dtype=np.float32) | |
else: | |
gr.Info("Number of left hand keypoints should be either 0 or 21.") | |
return None, None | |
keypts = np.concatenate(keypts, axis=0) | |
keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size) | |
# if keypts[0].sum() != 0 and keypts[21].sum() != 0: | |
# input_point = np.array([keypts[0], keypts[21]]) | |
# # input_point = keypts | |
# input_label = np.array([1, 1]) | |
# # input_label = np.ones_like(input_point[:, 0]) | |
# elif keypts[0].sum() != 0: | |
# input_point = np.array(keypts[:1]) | |
# # input_point = keypts[:21] | |
# input_label = np.array([1]) | |
# # input_label = np.ones_like(input_point[:21, 0]) | |
# elif keypts[21].sum() != 0: | |
# input_point = np.array(keypts[21:22]) | |
# # input_point = keypts[21:] | |
# input_label = np.array([1]) | |
# # input_label = np.ones_like(input_point[21:, 0]) | |
box_shift_ratio = 0.5 | |
box_size_factor = 1.2 | |
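    # SAM prompt: every available keypoint acts as a positive point, plus a bounding box around
    # those keypoints, enlarged by box_size_factor about a center interpolated between the box
    # corners by box_shift_ratio.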
if keypts[0].sum() != 0 and keypts[21].sum() != 0: | |
input_point = np.array(keypts) | |
input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)]) | |
elif keypts[0].sum() != 0: | |
input_point = np.array(keypts[:21]) | |
input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)]) | |
elif keypts[21].sum() != 0: | |
input_point = np.array(keypts[21:]) | |
input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)]) | |
else: | |
raise ValueError( | |
"Something wrong. If no hand detected, it should not reach here." | |
) | |
input_label = np.ones_like(input_point[:, 0]).astype(np.int32) | |
box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio) | |
input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1) | |
masks, _, _ = sam_predictor.predict( | |
point_coords=input_point, | |
point_labels=input_label, | |
box=input_box[None, :], | |
multimask_output=False, | |
) | |
hand_mask = masks[0] | |
inpaint_latent_mask = torch.tensor( | |
cv2.resize( | |
inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST | |
), | |
dtype=torch.float, | |
# device=device, | |
).unsqueeze(0)[None, ...] | |
def make_ref_cond( | |
img, | |
keypts, | |
hand_mask, | |
device=device, | |
target_size=(256, 256), | |
latent_size=(32, 32), | |
): | |
image_transform = Compose( | |
[ | |
ToTensor(), | |
Resize(target_size), | |
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), | |
] | |
) | |
image = image_transform(img) | |
kpts_valid = check_keypoints_validity(keypts, target_size) | |
heatmaps = torch.tensor( | |
keypoint_heatmap( | |
scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0 | |
) | |
* kpts_valid[:, None, None], | |
dtype=torch.float, | |
# device=device, | |
)[None, ...] | |
mask = torch.tensor( | |
cv2.resize( | |
hand_mask.astype(int), | |
dsize=latent_size, | |
interpolation=cv2.INTER_NEAREST, | |
), | |
dtype=torch.float, | |
# device=device, | |
).unsqueeze(0)[None, ...] | |
return image[None, ...], heatmaps, mask | |
image, heatmaps, mask = make_ref_cond( | |
img, | |
keypts, | |
hand_mask * (1 - inpaint_mask), | |
device=device, | |
target_size=opts.image_size, | |
latent_size=opts.latent_size, | |
) | |
latent = opts.latent_scaling_factor * autoencoder.encode(image).sample() | |
target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1) | |
ref_cond = torch.cat([latent, heatmaps, mask], 1) | |
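    # For "Fix Hands" the reference-appearance condition is zeroed out below, so generation is
    # driven by the target keypoint heatmaps and the unmasked regions of the latent instead.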
ref_cond = torch.zeros_like(ref_cond) | |
img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST) | |
assert mask.max() == 1 | |
vis_mask32 = mask_image( | |
img32, inpaint_latent_mask[0,0].cpu().numpy(), (255,255,255), transparent=False | |
).astype(np.uint8) # 1.0 - mask[0, 0].cpu().numpy() | |
assert np.unique(inpaint_mask).shape[0] <= 2 | |
assert hand_mask.dtype == bool | |
mask256 = inpaint_mask # hand_mask * (1 - inpaint_mask) | |
vis_mask256 = mask_image(img, mask256, (255,255,255), transparent=False).astype( | |
np.uint8 | |
) # 1 - mask256 | |
return ( | |
ref_cond, | |
target_cond, | |
latent, | |
inpaint_latent_mask, | |
keypts, | |
vis_mask32, | |
vis_mask256, | |
) | |
def switch_mask_size(radio):
    if radio == "256x256":
        out = (gr.update(visible=False), gr.update(visible=True))
    elif radio == "latent size (32x32)":
        out = (gr.update(visible=True), gr.update(visible=False))
    else:  # unknown option: leave both views unchanged
        out = (gr.update(), gr.update())
    return out
def sample_inpaint( | |
ref_cond, | |
target_cond, | |
latent, | |
inpaint_latent_mask, | |
keypts, | |
num_gen, | |
seed, | |
cfg, | |
quality, | |
): | |
set_seed(seed) | |
N = num_gen | |
jump_length = 10 | |
jump_n_sample = quality | |
cfg_scale = cfg | |
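    # jump_length / jump_n_sample configure RePaint-style resampling in inpaint_p_sample_loop:
    # the "Quality" slider maps to jump_n_sample, so higher values resample more aggressively
    # (better blending with the unmasked regions) at the cost of slower sampling.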
z = torch.randn( | |
(N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device | |
) | |
target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device) | |
ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device) | |
# novel view synthesis mode = off | |
nvs = torch.zeros(N, dtype=torch.int, device=device) | |
z = torch.cat([z, z], 0) | |
model_kwargs = dict( | |
target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]), | |
ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]), | |
nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]), | |
cfg_scale=cfg_scale, | |
) | |
samples, _ = diffusion.inpaint_p_sample_loop( | |
model.forward_with_cfg, | |
z.shape, | |
latent.to(z.device), | |
inpaint_latent_mask.to(z.device), | |
z, | |
clip_denoised=False, | |
model_kwargs=model_kwargs, | |
progress=True, | |
device=z.device, | |
jump_length=jump_length, | |
jump_n_sample=jump_n_sample, | |
).chunk(2) | |
sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor) | |
sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0) | |
sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy()) | |
# visualize | |
results = [] | |
results_pose = [] | |
for i in range(FIX_MAX_N): | |
if i < num_gen: | |
results.append(sampled_images[i]) | |
results_pose.append(visualize_hand(keypts, sampled_images[i])) | |
else: | |
results.append(placeholder) | |
results_pose.append(placeholder) | |
return results, results_pose | |
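# flip_hand mirrors the editor image (background, layers, composite), the pose visualization,
# and the conditioning tensor along the width axis; when keypoints are provided, their
# x-coordinates are mirrored as x -> image_width - x.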
def flip_hand( | |
img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None | |
): | |
if cond is None: # clear clicked | |
return None, None, None, None | |
img["composite"] = img["composite"][:, ::-1, :] | |
img["background"] = img["background"][:, ::-1, :] | |
img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]] | |
pose_img = pose_img[:, ::-1, :] | |
cond = cond.flip(-1) | |
if keypts is not None: # cond is target_cond | |
if keypts[:21, :].sum() != 0: | |
keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0] | |
# keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1] | |
if keypts[21:, :].sum() != 0: | |
keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0] | |
# keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1] | |
return img, pose_img, cond, keypts | |
def resize_to_full(img): | |
img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH)) | |
img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH)) | |
img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]] | |
return img | |
def clear_all(): | |
return ( | |
None, | |
None, | |
False, | |
None, | |
None, | |
False, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
1, | |
42, | |
3.0, | |
) | |
def fix_clear_all(): | |
return ( | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
None, | |
1, | |
# (0,0), | |
42, | |
3.0, | |
10, | |
) | |
def enable_component(image1, image2): | |
if image1 is None or image2 is None: | |
return gr.update(interactive=False) | |
if "background" in image1 and "layers" in image1 and "composite" in image1: | |
if ( | |
image1["background"].sum() == 0 | |
and (sum([im.sum() for im in image1["layers"]]) == 0) | |
and image1["composite"].sum() == 0 | |
): | |
return gr.update(interactive=False) | |
if "background" in image2 and "layers" in image2 and "composite" in image2: | |
if ( | |
image2["background"].sum() == 0 | |
and (sum([im.sum() for im in image2["layers"]]) == 0) | |
and image2["composite"].sum() == 0 | |
): | |
return gr.update(interactive=False) | |
return gr.update(interactive=True) | |
def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left): | |
if kpts is None: | |
kpts = [[], []] | |
if "Right hand" not in checkbox: | |
kpts[0] = [] | |
vis_right = img_clean | |
update_right = gr.update(visible=False) | |
update_r_info = gr.update(visible=False) | |
else: | |
vis_right = img_pose_right | |
update_right = gr.update(visible=True) | |
update_r_info = gr.update(visible=True) | |
if "Left hand" not in checkbox: | |
kpts[1] = [] | |
vis_left = img_clean | |
update_left = gr.update(visible=False) | |
update_l_info = gr.update(visible=False) | |
else: | |
vis_left = img_pose_left | |
update_left = gr.update(visible=True) | |
update_l_info = gr.update(visible=True) | |
return ( | |
kpts, | |
vis_right, | |
vis_left, | |
update_right, | |
update_right, | |
update_right, | |
update_left, | |
update_left, | |
update_left, | |
update_r_info, | |
update_l_info, | |
) | |
LENGTH = 480 | |
example_imgs = [ | |
[ | |
"sample_images/sample1.jpg", | |
], | |
[ | |
"sample_images/sample2.jpg", | |
], | |
[ | |
"sample_images/sample3.jpg", | |
], | |
[ | |
"sample_images/sample4.jpg", | |
], | |
[ | |
"sample_images/sample5.jpg", | |
], | |
[ | |
"sample_images/sample6.jpg", | |
], | |
[ | |
"sample_images/sample7.jpg", | |
], | |
[ | |
"sample_images/sample8.jpg", | |
], | |
[ | |
"sample_images/sample9.jpg", | |
], | |
[ | |
"sample_images/sample10.jpg", | |
], | |
[ | |
"sample_images/sample11.jpg", | |
], | |
["pose_images/pose1.jpg"], | |
["pose_images/pose2.jpg"], | |
["pose_images/pose3.jpg"], | |
["pose_images/pose4.jpg"], | |
["pose_images/pose5.jpg"], | |
["pose_images/pose6.jpg"], | |
["pose_images/pose7.jpg"], | |
["pose_images/pose8.jpg"], | |
] | |
fix_example_imgs = [ | |
["bad_hands/1.jpg"], # "bad_hands/1_mask.jpg"], | |
["bad_hands/2.jpg"], # "bad_hands/2_mask.jpg"], | |
["bad_hands/3.jpg"], # "bad_hands/3_mask.jpg"], | |
["bad_hands/4.jpg"], # "bad_hands/4_mask.jpg"], | |
["bad_hands/5.jpg"], # "bad_hands/5_mask.jpg"], | |
["bad_hands/6.jpg"], # "bad_hands/6_mask.jpg"], | |
["bad_hands/7.jpg"], # "bad_hands/7_mask.jpg"], | |
["bad_hands/8.jpg"], # "bad_hands/8_mask.jpg"], | |
["bad_hands/9.jpg"], # "bad_hands/9_mask.jpg"], | |
["bad_hands/10.jpg"], # "bad_hands/10_mask.jpg"], | |
["bad_hands/11.jpg"], # "bad_hands/11_mask.jpg"], | |
["bad_hands/12.jpg"], # "bad_hands/12_mask.jpg"], | |
["bad_hands/13.jpg"], # "bad_hands/13_mask.jpg"], | |
] | |
custom_css = """ | |
.gradio-container .examples img { | |
width: 240px !important; | |
height: 240px !important; | |
} | |
""" | |
_HEADER_ = ''' | |
<h1><b>FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation</b></h1> | |
<h2> | |
📝<a href='https://arxiv.org/abs/2412.02690' target='_blank'>Paper</a> | |
📢<a href='https://ivl.cs.brown.edu/research/foundhand.html' target='_blank'>Project</a> | |
</h2> | |
''' | |
_CITE_ = r""" | |
``` | |
@article{chen2024foundhand, | |
title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation}, | |
author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath}, | |
journal={arXiv preprint arXiv:2412.02690}, | |
year={2024} | |
} | |
``` | |
""" | |
with gr.Blocks(css=custom_css) as demo: | |
gr.Markdown(_HEADER_) | |
with gr.Tab("Edit Hand Poses"): | |
ref_img = gr.State(value=None) | |
ref_cond = gr.State(value=None) | |
keypts = gr.State(value=None) | |
target_img = gr.State(value=None) | |
target_cond = gr.State(value=None) | |
target_keypts = gr.State(value=None) | |
dump = gr.State(value=None) | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Reference</p>""" | |
) | |
gr.Markdown("""<p style="text-align: center;"><br></p>""") | |
ref = gr.ImageEditor( | |
type="numpy", | |
label="Reference", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
brush=False, | |
layers=False, | |
crop_size="1:1", | |
) | |
ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False) | |
ref_pose = gr.Image( | |
type="numpy", | |
label="Reference Pose", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
interactive=False, | |
) | |
ref_flip = gr.Checkbox( | |
value=False, label="Flip Handedness (Reference)", interactive=False | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold;">2. Target</p>""" | |
) | |
target = gr.ImageEditor( | |
type="numpy", | |
label="Target", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
brush=False, | |
layers=False, | |
crop_size="1:1", | |
) | |
target_finish_crop = gr.Button( | |
value="Finish Cropping", interactive=False | |
) | |
target_pose = gr.Image( | |
type="numpy", | |
label="Target Pose", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
interactive=False, | |
) | |
target_flip = gr.Checkbox( | |
value=False, label="Flip Handedness (Target)", interactive=False | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold;">3. Result</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">Run is enabled after the images have been processed</p>""" | |
) | |
run = gr.Button(value="Run", interactive=False) | |
gr.Markdown( | |
"""<p style="text-align: center;">~20s per generation with RTX3090. ~50s with A100. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>""" | |
) | |
results = gr.Gallery( | |
type="numpy", | |
label="Results", | |
show_label=True, | |
height=LENGTH, | |
min_width=LENGTH, | |
columns=MAX_N, | |
interactive=False, | |
preview=True, | |
) | |
results_pose = gr.Gallery( | |
type="numpy", | |
label="Results Pose", | |
show_label=True, | |
height=LENGTH, | |
min_width=LENGTH, | |
columns=MAX_N, | |
interactive=False, | |
preview=True, | |
) | |
clear = gr.ClearButton() | |
with gr.Row(): | |
n_generation = gr.Slider( | |
label="Number of generations", | |
value=1, | |
minimum=1, | |
maximum=MAX_N, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
seed = gr.Slider( | |
label="Seed", | |
value=42, | |
minimum=0, | |
maximum=10000, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
cfg = gr.Slider( | |
label="Classifier free guidance scale", | |
value=2.5, | |
minimum=0.0, | |
maximum=10.0, | |
step=0.1, | |
randomize=False, | |
interactive=True, | |
) | |
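        # Event wiring for the "Edit Hand Poses" tab: enable_component keeps downstream controls
        # disabled until both of its inputs are populated, "Finish Cropping" runs keypoint
        # detection and builds the conditioning tensors, and "Run" calls sample_diff with the
        # current slider values.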
ref.change(enable_component, [ref, ref], ref_finish_crop) | |
ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond]) | |
ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip) | |
ref_flip.select( | |
flip_hand, [ref, ref_pose, ref_cond], [ref, ref_pose, ref_cond, dump] | |
) | |
target.change(enable_component, [target, target], target_finish_crop) | |
target_finish_crop.click( | |
get_target_anno, | |
[target], | |
[target_img, target_pose, target_cond, target_keypts], | |
) | |
target_pose.change(enable_component, [target_img, target_pose], target_flip) | |
target_flip.select( | |
flip_hand, | |
[target, target_pose, target_cond, target_keypts], | |
[target, target_pose, target_cond, target_keypts], | |
) | |
ref_pose.change(enable_component, [ref_pose, target_pose], run) | |
target_pose.change(enable_component, [ref_pose, target_pose], run) | |
run.click( | |
sample_diff, | |
[ref_cond, target_cond, target_keypts, n_generation, seed, cfg], | |
[results, results_pose], | |
) | |
clear.click( | |
clear_all, | |
[], | |
[ | |
ref, | |
ref_pose, | |
ref_flip, | |
target, | |
target_pose, | |
target_flip, | |
results, | |
results_pose, | |
ref_img, | |
ref_cond, | |
# mask, | |
target_img, | |
target_cond, | |
target_keypts, | |
n_generation, | |
seed, | |
cfg, | |
], | |
) | |
gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""") | |
with gr.Tab("Reference"): | |
with gr.Row(): | |
gr.Examples(example_imgs, [ref], examples_per_page=20) | |
with gr.Tab("Target"): | |
with gr.Row(): | |
gr.Examples(example_imgs, [target], examples_per_page=20) | |
with gr.Tab("Fix Hands"): | |
fix_inpaint_mask = gr.State(value=None) | |
fix_original = gr.State(value=None) | |
fix_img = gr.State(value=None) | |
fix_kpts = gr.State(value=None) | |
fix_kpts_np = gr.State(value=None) | |
fix_ref_cond = gr.State(value=None) | |
fix_target_cond = gr.State(value=None) | |
fix_latent = gr.State(value=None) | |
fix_inpaint_latent = gr.State(value=None) | |
# fix_size_memory = gr.State(value=(0, 0)) | |
gr.Markdown("""<p style="text-align: center; font-size: 25px; font-weight: bold; ">⚠️ Note</p>""") | |
gr.Markdown("""<p>"Fix Hands" with A100 needs around 6 mins, which is beyond the ZeroGPU quota (5 mins). Please either purchase additional gpus from Hugging Face or wait for us to open-source our code soon so that you can use your own gpus🙏 </p>""") | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">Crop the image around the hand.<br>Then, brush area (e.g., wrong finger) that needs to be fixed.</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>""" | |
) | |
fix_crop = gr.ImageEditor( | |
type="numpy", | |
sources=["upload", "webcam", "clipboard"], | |
label="Image crop", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
layers=False, | |
crop_size="1:1", | |
brush=False, | |
image_mode="RGBA", | |
container=False, | |
) | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>""" | |
) | |
fix_ref = gr.ImageEditor( | |
type="numpy", | |
label="Image brush", | |
sources=(), | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
layers=False, | |
transforms=("brush"), | |
brush=gr.Brush( | |
colors=["rgb(255, 255, 255)"], default_size=20 | |
), # 204, 50, 50 | |
image_mode="RGBA", | |
container=False, | |
interactive=False, | |
) | |
fix_finish_crop = gr.Button( | |
value="Finish Croping & Brushing", interactive=False | |
) | |
gr.Markdown( | |
"""<p style="text-align: left; font-size: 20px; font-weight: bold; ">OpenPose keypoints convention</p>""" | |
) | |
fix_openpose = gr.Image( | |
value="openpose.png", | |
type="numpy", | |
label="OpenPose keypoints convention", | |
show_label=True, | |
height=LENGTH // 3 * 2, | |
width=LENGTH // 3 * 2, | |
interactive=False, | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\" on the bottom left.</p>""" | |
) | |
fix_checkbox = gr.CheckboxGroup( | |
["Right hand", "Left hand"], | |
# value=["Right hand", "Left hand"], | |
label="Hand side", | |
info="Which side this hand is? Could be both.", | |
interactive=False, | |
) | |
fix_kp_r_info = gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""", | |
visible=False, | |
) | |
fix_kp_right = gr.Image( | |
type="numpy", | |
label="Keypoint Selection (right hand)", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
interactive=False, | |
visible=False, | |
sources=[], | |
) | |
with gr.Row(): | |
fix_undo_right = gr.Button( | |
value="Undo", interactive=False, visible=False | |
) | |
fix_reset_right = gr.Button( | |
value="Reset", interactive=False, visible=False | |
) | |
fix_kp_l_info = gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""", | |
visible=False | |
) | |
fix_kp_left = gr.Image( | |
type="numpy", | |
label="Keypoint Selection (left hand)", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
interactive=False, | |
visible=False, | |
sources=[], | |
) | |
with gr.Row(): | |
fix_undo_left = gr.Button( | |
value="Undo", interactive=False, visible=False | |
) | |
fix_reset_left = gr.Button( | |
value="Reset", interactive=False, visible=False | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">In Fix Hands, not segmentation mask, but only inpaint mask is used.</p>""" | |
) | |
fix_ready = gr.Button(value="Ready", interactive=False) | |
fix_mask_size = gr.Radio( | |
["256x256", "latent size (32x32)"], | |
label="Visualized inpaint mask size", | |
interactive=False, | |
value="256x256", | |
) | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Visualized inpaint masks</p>""" | |
) | |
fix_vis_mask32 = gr.Image( | |
type="numpy", | |
label=f"Visualized {opts.latent_size} Inpaint Mask", | |
show_label=True, | |
height=opts.latent_size, | |
width=opts.latent_size, | |
interactive=False, | |
visible=False, | |
) | |
fix_vis_mask256 = gr.Image( | |
type="numpy", | |
label=f"Visualized {opts.image_size} Inpaint Mask", | |
visible=True, | |
show_label=True, | |
height=opts.image_size, | |
width=opts.image_size, | |
interactive=False, | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>""" | |
) | |
fix_run = gr.Button(value="Run", interactive=False) | |
gr.Markdown( | |
"""<p style="text-align: center;">>3min and ~24GB per generation</p>""" | |
) | |
fix_result = gr.Gallery( | |
type="numpy", | |
label="Results", | |
show_label=True, | |
height=LENGTH, | |
min_width=LENGTH, | |
columns=FIX_MAX_N, | |
interactive=False, | |
preview=True, | |
) | |
fix_result_pose = gr.Gallery( | |
type="numpy", | |
label="Results Pose", | |
show_label=True, | |
height=LENGTH, | |
min_width=LENGTH, | |
columns=FIX_MAX_N, | |
interactive=False, | |
preview=True, | |
) | |
fix_clear = gr.ClearButton() | |
gr.Markdown( | |
"[NOTE] Currently, Number of generation > 1 could lead to out-of-memory" | |
) | |
with gr.Row(): | |
fix_n_generation = gr.Slider( | |
label="Number of generations", | |
value=1, | |
minimum=1, | |
maximum=FIX_MAX_N, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_seed = gr.Slider( | |
label="Seed", | |
value=42, | |
minimum=0, | |
maximum=10000, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_cfg = gr.Slider( | |
label="Classifier free guidance scale", | |
value=3.0, | |
minimum=0.0, | |
maximum=10.0, | |
step=0.1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_quality = gr.Slider( | |
label="Quality", | |
value=10, | |
minimum=1, | |
maximum=10, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref) | |
fix_crop.change(resize_to_full, fix_crop, fix_ref) | |
fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop) | |
fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask]) | |
# fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right]) | |
# fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left]) | |
fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original]) | |
fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img]) | |
fix_img.change(lambda x: x, [fix_img], [fix_kp_right]) | |
fix_img.change(lambda x: x, [fix_img], [fix_kp_left]) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left | |
) | |
fix_inpaint_mask.change( | |
enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready | |
) | |
# fix_inpaint_mask.change( | |
# enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run | |
# ) | |
fix_checkbox.select( | |
set_visible, | |
[fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left], | |
[ | |
fix_kpts, | |
fix_kp_right, | |
fix_kp_left, | |
fix_kp_right, | |
fix_undo_right, | |
fix_reset_right, | |
fix_kp_left, | |
fix_undo_left, | |
fix_reset_left, | |
fix_kp_r_info, | |
fix_kp_l_info, | |
], | |
) | |
fix_kp_right.select( | |
get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] | |
) | |
fix_undo_right.click( | |
undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] | |
) | |
fix_reset_right.click( | |
reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] | |
) | |
fix_kp_left.select( | |
get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts] | |
) | |
fix_undo_left.click( | |
undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts] | |
) | |
fix_reset_left.click( | |
reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts] | |
) | |
# fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run]) | |
# fix_run.click(lambda x:gr.update(value=None), [], [fix_result, fix_result_pose]) | |
fix_vis_mask32.change( | |
enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run | |
) | |
fix_vis_mask32.change( | |
enable_component, [fix_vis_mask32, fix_vis_mask256], fix_mask_size | |
) | |
fix_ready.click( | |
ready_sample, | |
[fix_original, fix_inpaint_mask, fix_kpts], | |
[ | |
fix_ref_cond, | |
fix_target_cond, | |
fix_latent, | |
fix_inpaint_latent, | |
fix_kpts_np, | |
fix_vis_mask32, | |
fix_vis_mask256, | |
], | |
) | |
fix_mask_size.select( | |
switch_mask_size, [fix_mask_size], [fix_vis_mask32, fix_vis_mask256] | |
) | |
fix_run.click( | |
sample_inpaint, | |
[ | |
fix_ref_cond, | |
fix_target_cond, | |
fix_latent, | |
fix_inpaint_latent, | |
fix_kpts_np, | |
fix_n_generation, | |
fix_seed, | |
fix_cfg, | |
fix_quality, | |
], | |
[fix_result, fix_result_pose], | |
) | |
fix_clear.click( | |
fix_clear_all, | |
[], | |
[ | |
fix_crop, | |
fix_ref, | |
fix_kp_right, | |
fix_kp_left, | |
fix_result, | |
fix_result_pose, | |
fix_inpaint_mask, | |
fix_original, | |
fix_img, | |
fix_vis_mask32, | |
fix_vis_mask256, | |
fix_kpts, | |
fix_kpts_np, | |
fix_ref_cond, | |
fix_target_cond, | |
fix_latent, | |
fix_inpaint_latent, | |
fix_n_generation, | |
# fix_size_memory, | |
fix_seed, | |
fix_cfg, | |
fix_quality, | |
], | |
) | |
gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""") | |
fix_dump_ex = gr.Image(value=None, label="Original Image", visible=False) | |
fix_dump_ex_masked = gr.Image(value=None, label="After Brushing", visible=False) | |
with gr.Column(): | |
fix_example = gr.Examples( | |
fix_example_imgs, | |
# run_on_click=True, | |
# fn=parse_fix_example, | |
# inputs=[fix_dump_ex, fix_dump_ex_masked], | |
# outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask], | |
inputs=[fix_crop], | |
examples_per_page=20, | |
) | |
gr.Markdown("<h1>Citation</h1>") | |
gr.Markdown(_CITE_) | |
# print("Ready to launch..") | |
# _, _, shared_url = demo.queue().launch( | |
# share=True, server_name="0.0.0.0", server_port=7739 | |
# ) | |
demo.launch(share=True) | |