import argparse
import glob
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from transformers import AutoTokenizer, CLIPImageProcessor

from model.LISA import LISA
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.conversation import get_default_conv_template

# Multimodal placeholder tokens. The image placeholder in the user prompt is
# expanded to IM_START + IMAGE_PATCH * IMAGE_TOKEN_LEN + IM_END before
# tokenization.
# NOTE(review): these were empty strings in the source, which turned the
# expansion into a no-op and left the prompt without any image slots. They are
# restored to the conventional LLaVA-style values -- confirm they match the
# tokens registered by model.LISA.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_TOKEN_LEN = 256

# Torch dtype used for model inputs under each --precision choice.
_PRECISION_DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.half,
    "fp32": torch.float32,
}

_VISUAL_WEIGHTS = "pytorch_model-visual_model.bin"
_TEXT_FCS_WEIGHTS = "pytorch_model-text_hidden_fcs.bin"


def parse_args(args):
    """Parse the command-line arguments for the LISA chat script."""
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v0")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image-size", default=1024, type=int, help="image size")
    parser.add_argument("--model-max-length", default=512, type=int)
    parser.add_argument("--lora-r", default=-1, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    return parser.parse_args(args)


def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    Args:
        x: image tensor whose last two dims are (H, W). Presumably (3, H, W)
            RGB in the 0-255 range, given the per-channel defaults -- confirm.
        pixel_mean: per-channel mean, shaped (C, 1, 1) for broadcasting.
        pixel_std: per-channel std, shaped (C, 1, 1) for broadcasting.
        img_size: side length of the square output; padding is added on the
            bottom/right only.

    Returns:
        The normalized tensor, zero-padded to (..., img_size, img_size).
    """
    # Normalize colors
    x = (x - pixel_mean) / pixel_std
    # Pad bottom/right only, so the original content keeps its pixel
    # coordinates. If H or W already exceeds img_size the pad amount is
    # negative and F.pad crops instead.
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    return F.pad(x, (0, padw, 0, padh))


def _to_device(tensor, precision):
    """Move ``tensor`` to the current CUDA device and cast to ``precision``'s dtype."""
    return tensor.cuda().to(_PRECISION_DTYPES[precision])


def _resolve_model_dir(version):
    """Return the directory holding LISA's extra weight files for ``version``.

    ``version`` is either an existing local path (returned as-is) or a
    ``user/model`` id whose snapshot is looked up in the local Hugging Face
    hub cache under ``~/.cache/huggingface/hub``.

    Raises:
        ValueError: if the cache entry is missing, or no snapshot contains
            both weight files.
    """
    if os.path.exists(version):
        return version

    # hack for cached pre-trained weights
    user_name, model_name = version.split("/")
    cache_dir = "{}/.cache/huggingface/hub/models--{}--{}".format(
        os.path.expanduser("~"), user_name, model_name
    )
    if not os.path.exists(cache_dir):
        raise ValueError("The path {} does not exists.".format(cache_dir))

    def snapshot_dirs(file_name):
        # Snapshot directories that contain ``file_name``.
        pattern = os.path.join(glob.escape(cache_dir), "snapshots", "*", file_name)
        return {os.path.dirname(path) for path in glob.glob(pattern)}

    # Both weight files must come from the same snapshot; sort so the choice
    # is deterministic when several snapshots qualify.
    candidates = sorted(
        snapshot_dirs(_VISUAL_WEIGHTS) & snapshot_dirs(_TEXT_FCS_WEIGHTS)
    )
    if not candidates:
        raise ValueError(
            "Pre-trained weights for visual_model or text_hidden_fcs do not exist in {}.".format(
                cache_dir
            )
        )
    return candidates[0]


def _load_extra_weights(model, model_dir):
    """Load the visual_model and text_hidden_fcs state dicts into ``model``.

    Loading is non-strict (``strict=False``), so keys absent from these two
    files are left as initialized by the LISA constructor.
    """
    weight = {}
    for file_name in (_VISUAL_WEIGHTS, _TEXT_FCS_WEIGHTS):
        # Load onto CPU first; load_state_dict copies into the model's devices.
        weight.update(
            torch.load(os.path.join(model_dir, file_name), map_location="cpu")
        )
    model.load_state_dict(weight, strict=False)


def _prepare_model(model, precision):
    """Cast/wrap ``model`` for inference at ``precision`` and place it on GPU."""
    if precision == "bf16":
        return model.bfloat16().cuda()
    if precision == "fp16":
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        return model_engine.module
    return model.float().cuda()


def _build_prompt(user_prompt):
    """Wrap ``user_prompt`` with image placeholder tokens in a vicuna conversation."""
    conv = get_default_conv_template("vicuna").copy()
    conv.messages = []
    prompt = DEFAULT_IMAGE_TOKEN + " " + user_prompt
    replace_token = (
        DEFAULT_IM_START_TOKEN
        + DEFAULT_IMAGE_PATCH_TOKEN * IMAGE_TOKEN_LEN
        + DEFAULT_IM_END_TOKEN
    )
    prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "")
    return conv.get_prompt()


def _save_visualizations(image, pred_masks, image_path, vis_save_path):
    """Write a mask image and a red-overlay image for each non-empty mask.

    ``image`` is the RGB uint8 array; outputs are written as BGR JPEGs.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    for i, pred_mask in enumerate(pred_masks):
        if pred_mask.shape[0] == 0:
            continue
        # Threshold in torch first: a bool tensor converts to numpy even when
        # the logits are bf16, which numpy cannot represent.
        mask = (pred_mask[0] > 0).detach().cpu().numpy()

        save_path = os.path.join(vis_save_path, "{}_mask_{}.jpg".format(stem, i))
        cv2.imwrite(save_path, (mask * 100).astype(np.uint8))
        print("{} has been saved.".format(save_path))

        save_path = os.path.join(
            vis_save_path, "{}_masked_img_{}.jpg".format(stem, i)
        )
        save_img = image.copy()
        # 50/50 blend of the masked pixels with pure red.
        save_img[mask] = (
            image[mask] * 0.5 + np.array([255, 0, 0]) * 0.5
        ).astype(np.uint8)
        save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(save_path, save_img)
        print("{} has been saved.".format(save_path))


def main(args):
    """Run the interactive LISA chat loop.

    Loads the tokenizer, LISA model and extra weights, then repeatedly reads a
    text prompt and an image path from stdin. It prints the model's text
    answer and saves mask visualizations under ``--vis_save_path``. The loop
    ends on EOF.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    model = LISA(
        args.local_rank,
        args.seg_token_idx,
        tokenizer,
        args.version,
        args.lora_r,
        args.precision,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )
    _load_extra_weights(model, _resolve_model_dir(args.version))
    model = _prepare_model(model, args.precision)

    clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    while True:
        try:
            prompt = _build_prompt(input("Please input your prompt: "))
            image_path = input("Please input the image path: ")
        except EOFError:
            break

        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue
        image = cv2.imread(image_path)
        if image is None:  # exists but is not a decodable image
            print("Failed to read image {}".format(image_path))
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_size_list = [image.shape[:2]]

        images_clip = _to_device(
            clip_image_processor.preprocess(image, return_tensors="pt")[
                "pixel_values"
            ][0].unsqueeze(0),
            args.precision,
        )

        resized = transform.apply_image(image)
        resize_list = [resized.shape[:2]]
        images = _to_device(
            preprocess(
                torch.from_numpy(resized).permute(2, 0, 1).contiguous(),
                img_size=args.image_size,
            ).unsqueeze(0),
            args.precision,
        )

        input_ids = torch.LongTensor(tokenizer(prompt).input_ids).unsqueeze(0).cuda()

        with torch.no_grad():
            output_ids, pred_masks = model.evaluate(
                images_clip,
                images,
                input_ids,
                resize_list,
                original_size_list,
                max_new_tokens=512,
                tokenizer=tokenizer,
            )

        text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
        text_output = text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "").replace(
            "\n", ""
        )
        # Collapse whitespace runs left by the removed placeholder tokens while
        # keeping single spaces between words.
        text_output = " ".join(text_output.split())
        print("text_output: ", text_output)

        _save_visualizations(image, pred_masks, image_path, args.vis_save_path)


if __name__ == "__main__":
    main(sys.argv[1:])