|
import argparse |
|
import os |
|
import sys |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
from transformers import AutoTokenizer, BitsAndBytesConfig |
|
from model.segment_anything.utils.transforms import ResizeLongestSide |
|
|
|
|
|
|
|
def parse_args(args):
    '''
    Parse the command-line options of the EVF inference script.

    input: list of argument strings (e.g. ``sys.argv[1:]``)
    output: argparse.Namespace with the parsed options
    '''
    # (flag, add_argument keyword arguments) pairs, registered in this exact
    # order so the generated ``--help`` output keeps its layout.
    option_table = (
        ("--version", dict(required=True)),
        ("--vis_save_path", dict(default="./infer", type=str)),
        (
            "--precision",
            dict(
                default="fp16",
                type=str,
                choices=["fp32", "bf16", "fp16"],
                help="precision for inference",
            ),
        ),
        ("--image_size", dict(default=224, type=int, help="image size")),
        ("--model_max_length", dict(default=512, type=int)),
        ("--local-rank", dict(default=0, type=int, help="node rank")),
        ("--load_in_8bit", dict(action="store_true", default=False)),
        ("--load_in_4bit", dict(action="store_true", default=False)),
        ("--model_type", dict(default="ori", choices=["ori", "effi", "sam2"])),
        ("--image_path", dict(type=str, default="assets/zebra.jpg")),
        ("--prompt", dict(type=str, default="zebra top left")),
    )

    cli = argparse.ArgumentParser(description="EVF infer")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args(args)
|
|
|
|
|
def sam_preprocess(
    x: np.ndarray,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
    model_type="ori") -> tuple:
    '''
    preprocess of Segment Anything Model, including scaling, normalization and padding.
    preprocess differs between SAM and Effi-SAM, where Effi-SAM use no padding.

    input:
        x: ndarray image in height x width x channel layout (it is permuted
           to channel-first here); three channels are expected because
           pixel_mean / pixel_std hold three values.
        pixel_mean, pixel_std: per-channel statistics, shaped (C, 1, 1).
        img_size: side length of the square model input; must be 1024.
        model_type: "ori" resizes the longest side to img_size and zero-pads
           bottom/right; any other value stretches the image directly to
           img_size x img_size without padding.
    output:
        (tensor, resize_shape) where tensor is the normalized
        (C, img_size, img_size) image and resize_shape is the (h, w) of the
        image after the longest-side resize for "ori", or None otherwise.
    '''
    assert img_size==1024, \
        "both SAM and Effi-SAM receive images of size 1024^2, don't change this setting unless you're sure that your employed model works well with another size."

    if model_type == "ori":
        x = ResizeLongestSide(img_size).apply_image(x)
        h, w = resize_shape = x.shape[:2]
        x = torch.from_numpy(x).permute(2, 0, 1).contiguous()
        # Subtracting the float mean promotes the integer image to float.
        x = (x - pixel_mean) / pixel_std

        # Zero-pad on the right and bottom up to a square img_size input.
        padh = img_size - h
        padw = img_size - w
        x = F.pad(x, (0, padw, 0, padh))
    else:
        x = torch.from_numpy(x).permute(2, 0, 1).contiguous()
        # Bilinear interpolation must run in floating point: on an integer
        # tensor it is either unsupported (older PyTorch) or rounds every
        # resized pixel back to 8 bits before normalization.
        if not x.is_floating_point():
            x = x.float()
        x = F.interpolate(
            x.unsqueeze(0), (img_size, img_size), mode="bilinear", align_corners=False
        ).squeeze(0)
        x = (x - pixel_mean) / pixel_std
        resize_shape = None

    return x, resize_shape
|
|
|
def beit3_preprocess(x: np.ndarray, img_size=224) -> torch.Tensor:
    '''
    preprocess for BEIT-3 model: tensor conversion, bicubic resize to a
    square of side img_size, then normalization with mean 0.5 / std 0.5
    on each of the three channels.
    input: ndarray
    output: torch.Tensor
    '''
    target_hw = (img_size, img_size)
    half = (0.5, 0.5, 0.5)

    # Apply the three torchvision transforms one after another, in the same
    # order a Compose pipeline would run them.
    as_tensor = transforms.ToTensor()(x)
    resized = transforms.Resize(target_hw, interpolation=InterpolationMode.BICUBIC)(as_tensor)
    return transforms.Normalize(mean=half, std=half)(resized)
|
|
|
def init_models(args):
    '''
    Build the tokenizer and the EVF model selected by the parsed options.

    input: args namespace providing version, precision, load_in_4bit,
           load_in_8bit and model_type
    output: (tokenizer, model) with the model switched to eval mode

    Raises ValueError when args.model_type is not "ori", "effi" or "sam2".
    '''
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        padding_side="right",
        use_fast=False,
    )

    # Any precision other than bf16 / fp16 falls back to float32.
    torch_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.half,
    }.get(args.precision, torch.float32)

    kwargs = {"torch_dtype": torch_dtype}
    # Quantized loading overrides the dtype with fp16; 4-bit takes priority
    # when both flags are given.
    if args.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    # The model classes are imported lazily so only the selected backend's
    # module has to be importable.
    if args.model_type == "ori":
        from model.evf_sam import EvfSamModel as model_cls
    elif args.model_type == "effi":
        from model.evf_effisam import EvfEffiSamModel as model_cls
    elif args.model_type == "sam2":
        from model.evf_sam2 import EvfSam2Model as model_cls
    else:
        # Previously an unknown type left the model unbound and surfaced
        # later as an UnboundLocalError.
        raise ValueError(
            "unknown model_type {!r}, expected one of 'ori', 'effi', 'sam2'".format(
                args.model_type
            )
        )

    model = model_cls.from_pretrained(
        args.version, low_cpu_mem_usage=True, **kwargs
    )

    # Only non-quantized models are moved to the GPU explicitly here.
    if (not args.load_in_4bit) and (not args.load_in_8bit):
        model = model.cuda()
    model.eval()

    return tokenizer, model
|
|
|
def main(args):
    '''
    Run referring segmentation on one image and save a visualization.

    Parses the command-line options, loads the image, tokenizer and model,
    predicts a mask for the text prompt and writes the input image with the
    mask blended in to "<vis_save_path>/<image name>_vis.png".

    input: list of command-line argument strings
    output: None; exits with status 1 when the image is missing or unreadable
    '''
    args = parse_args(args)

    image_path = args.image_path
    if not os.path.exists(image_path):
        print("File not found in {}".format(image_path))
        # sys.exit instead of the site-provided exit(); a non-zero status
        # lets calling scripts detect the failure.
        sys.exit(1)
    prompt = args.prompt

    os.makedirs(args.vis_save_path, exist_ok=True)
    save_path = "{}/{}_vis.png".format(
        args.vis_save_path, os.path.basename(image_path).split(".")[0]
    )

    # Read the image before loading the model so an unreadable file fails
    # fast. cv2.imread returns None instead of raising on failure.
    image_np = cv2.imread(image_path)
    if image_np is None:
        print("Failed to read image from {}".format(image_path))
        sys.exit(1)
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    tokenizer, model = init_models(args)

    image_beit = beit3_preprocess(image_np, args.image_size).to(dtype=model.dtype, device=model.device)

    image_sam, resize_shape = sam_preprocess(image_np, model_type=args.model_type)
    image_sam = image_sam.to(dtype=model.dtype, device=model.device)

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device=model.device)

    # Gradients are never needed for inference.
    with torch.no_grad():
        pred_mask = model.inference(
            image_sam.unsqueeze(0),
            image_beit.unsqueeze(0),
            input_ids,
            resize_list=[resize_shape],
            original_size_list=original_size_list,
        )
    # Cast to float32 first: NumPy has no bfloat16, so .numpy() would fail
    # if the model returns the mask in bf16.
    pred_mask = pred_mask.detach().cpu().float().numpy()[0]
    # Positive values are treated as foreground.
    # NOTE(review): assumes the result is a (height, width) map matching the
    # input image - confirm against model.inference.
    pred_mask = pred_mask > 0

    # Blend the highlight colour 50/50 into the masked pixels only; the
    # assignment casts the float result back to the image dtype.
    highlight = np.array([50, 120, 220])
    save_img = image_np.copy()
    save_img[pred_mask] = image_np[pred_mask] * 0.5 + highlight * 0.5
    save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)

    cv2.imwrite(save_path, save_img)
|
|
|
if __name__ == "__main__":
    # Script entry point: hand the command-line arguments (without the
    # program name) to main().
    cli_arguments = sys.argv[1:]
    main(cli_arguments)