# Captured from a Spaces file-viewer page (status: Sleeping).
# File size: 4,261 bytes; commit b09c1e6.
import cv2
import sys
import argparse
import numpy as np
import torch
from pathlib import Path
from matplotlib import pyplot as plt
from typing import Any, Dict, List
from sam_segment import predict_masks_with_sam
from stable_diffusion_inpaint import fill_img_with_sd
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
show_mask, show_points
def setup_args(parser):
    """Register the fill-anything command-line options on *parser*.

    The options are added in place; nothing is returned.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend.
    """
    parser.add_argument(
        "--input_img", type=str, required=True,
        help="Path to a single input img",
    )
    # The caller wraps this value as ONE point ([point_coords]), so exactly
    # two numbers (W, H) are required; any other count would silently build
    # a malformed point.
    parser.add_argument(
        "--point_coords", type=float, nargs=2, required=True,
        metavar=("coord_W", "coord_H"),
        help="The coordinate of the point prompt, [coord_W coord_H].",
    )
    # Each label is validated against the documented values.
    parser.add_argument(
        "--point_labels", type=int, nargs='+', required=True,
        choices=[0, 1],
        help="The labels of the point prompt, 1 or 0.",
    )
    parser.add_argument(
        "--text_prompt", type=str, required=True,
        help="Text prompt",
    )
    parser.add_argument(
        "--dilate_kernel_size", type=int, default=None,
        help="Dilate kernel size. Default: None",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output path to the directory with results.",
    )
    parser.add_argument(
        "--sam_model_type", type=str,
        default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
        help="The type of sam model to load. Default: 'vit_h'",
    )
    parser.add_argument(
        "--sam_ckpt", type=str, required=True,
        help="The path to the SAM checkpoint to use for mask generation.",
    )
    parser.add_argument(
        "--seed", type=int,
        help="Specify seed for reproducibility.",
    )
    parser.add_argument(
        "--deterministic", action="store_true",
        help="Use deterministic algorithms for reproducibility.",
    )
if __name__ == "__main__":
    """Example usage:
    python fill_anything.py \
        --input_img FA_demo/FA1_dog.png \
        --point_coords 750 500 \
        --point_labels 1 \
        --text_prompt "a teddy bear on a bench" \
        --dilate_kernel_size 15 \
        --output_dir ./results \
        --sam_model_type "vit_h" \
        --sam_ckpt sam_vit_h_4b8939.pth
    """
    import os

    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])

    # --deterministic used to be parsed and then ignored. Apply it before any
    # model is loaded so that every later torch op is affected.
    if args.deterministic:
        # cuBLAS needs this workspace setting for deterministic CUDA matmuls;
        # it only takes effect if set before cuBLAS is initialised.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # warn_only=True: ops lacking a deterministic implementation emit a
        # warning instead of aborting the whole pipeline.
        torch.use_deterministic_algorithms(True, warn_only=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    img = load_img_to_array(args.input_img)

    # Segment with SAM using the single point prompt.
    masks, _, _ = predict_masks_with_sam(
        img,
        [args.point_coords],
        args.point_labels,
        model_type=args.sam_model_type,
        ckpt_p=args.sam_ckpt,
        device=device,
    )
    masks = masks.astype(np.uint8) * 255

    # dilate mask to avoid unmasked edge effect
    if args.dilate_kernel_size is not None:
        masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]

    img_stem = Path(args.input_img).stem
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    # Figure geometry depends only on the input image, so compute it once.
    dpi = plt.rcParams['figure.dpi']
    height, width = img.shape[:2]
    figsize = (width / dpi / 0.77, height / dpi / 0.77)
    point_size = (width * 0.04) ** 2
    img_points_p = out_dir / "with_points.png"

    # visualize the segmentation results
    for idx, mask in enumerate(masks):
        mask_p = out_dir / f"mask_{idx}.png"
        img_mask_p = out_dir / f"with_{mask_p.name}"
        # save the mask
        save_array_to_img(mask, mask_p)
        # save the pointed and masked image
        plt.figure(figsize=figsize)
        plt.imshow(img)
        plt.axis('off')
        show_points(plt.gca(), [args.point_coords], args.point_labels,
                    size=point_size)
        # The points-only image is identical for every mask; write it once.
        if idx == 0:
            plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
        show_mask(plt.gca(), mask, random_color=False)
        plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
        plt.close()

    # fill the masked image
    for idx, mask in enumerate(masks):
        # Re-seed per mask so that each fill is reproducible on its own.
        if args.seed is not None:
            torch.manual_seed(args.seed)
        mask_p = out_dir / f"mask_{idx}.png"
        img_filled_p = out_dir / f"filled_with_{mask_p.name}"
        img_filled = fill_img_with_sd(
            img, mask, args.text_prompt, device=device)
        save_array_to_img(img_filled, img_filled_p)