# lisa-on-cuda / chat.py
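"""Interactive chat demo for LISA (reasoning segmentation).

Loads a LISA checkpoint, then repeatedly reads a text prompt and an image
path from stdin, runs inference, and writes each predicted mask to
--vis_save_path, both as a raw mask and overlaid on the input image.
"""
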
import argparse
import os
import sys
import cv2
import glob
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, CLIPImageProcessor
from model.LISA import LISA
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.conversation import get_default_conv_template
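

# Command-line arguments for the interactive demo.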
def parse_args(args):
parser = argparse.ArgumentParser(description="LISA chat")
parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v0")
parser.add_argument("--vis_save_path", default="./vis_output", type=str)
parser.add_argument(
"--precision",
default="bf16",
type=str,
choices=["fp32", "bf16", "fp16"],
help="precision for inference",
)
parser.add_argument("--image-size", default=1024, type=int, help="image size")
parser.add_argument("--model-max-length", default=512, type=int)
parser.add_argument("--lora-r", default=-1, type=int)
parser.add_argument(
"--vision-tower", default="openai/clip-vit-large-patch14", type=str
)
parser.add_argument("--local-rank", default=0, type=int, help="node rank")
parser.add_argument("--load_in_8bit", action="store_true", default=False)
parser.add_argument("--load_in_4bit", action="store_true", default=False)
return parser.parse_args(args)
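

# SAM-style image preprocessing: the mean/std defaults below are the standard
# ImageNet statistics scaled to [0, 255], as used by Segment Anything's
# image encoder.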
def preprocess(
x,
pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
img_size=1024,
) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - pixel_mean) / pixel_std
# Pad
h, w = x.shape[-2:]
padh = img_size - h
padw = img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
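

# Entry point: build the model once, then serve an interactive prompt/image
# loop on stdin.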
def main(args):
args = parse_args(args)
os.makedirs(args.vis_save_path, exist_ok=True)
# Create model
    tokenizer = AutoTokenizer.from_pretrained(
args.version,
cache_dir=None,
model_max_length=args.model_max_length,
padding_side="right",
use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
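
    # Register the special [SEG] token; its vocabulary index is passed to the
    # model, which decodes a segmentation mask from the hidden state at this
    # token's position.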
num_added_tokens = tokenizer.add_tokens("[SEG]")
ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids
args.seg_token_idx = ret_token_idx[0]
model = LISA(
args.local_rank,
args.seg_token_idx,
tokenizer,
args.version,
args.lora_r,
args.precision,
load_in_8bit=args.load_in_8bit,
load_in_4bit=args.load_in_4bit,
)
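
    # Resolve the directory holding LISA's extra weights (SAM visual model and
    # text projection). A local --version path is used directly; otherwise we
    # search the Hugging Face cache for a snapshot containing both files.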
if os.path.exists(args.version):
model_dir = args.version
else: # hack for cached pre-trained weights
user_name, model_name = args.version.split("/")
cache_dir = "{}/.cache/huggingface/hub/models--{}--{}".format(os.environ['HOME'], user_name, model_name)
if os.path.exists(cache_dir):
model1_dir = glob.glob("{}/snapshots/*/pytorch_model-visual_model.bin".format(cache_dir))
model2_dir = glob.glob("{}/snapshots/*/pytorch_model-text_hidden_fcs.bin".format(cache_dir))
if len(model1_dir) == 0 or len(model2_dir) == 0:
raise ValueError("Pre-trained weights for visual_model or text_hidden_fcs do not exist in {}.".format(
cache_dir
))
model1_dir = ["/".join(x.split("/")[:-1]) for x in model1_dir]
model2_dir = ["/".join(x.split("/")[:-1]) for x in model2_dir]
model_dir = list(set(model1_dir).intersection(set(model2_dir)))
            if len(model_dir) == 0:
                raise ValueError(
                    "No snapshot in {} contains both the visual_model and "
                    "text_hidden_fcs weights.".format(cache_dir)
                )
model_dir = model_dir[0]
else:
raise ValueError("The path {} does not exists.".format(cache_dir))
weight = {}
visual_model_weight = torch.load(
os.path.join(model_dir, "pytorch_model-visual_model.bin")
)
text_hidden_fcs_weight = torch.load(
os.path.join(model_dir, "pytorch_model-text_hidden_fcs.bin")
)
weight.update(visual_model_weight)
weight.update(text_hidden_fcs_weight)
missing_keys, unexpected_keys = model.load_state_dict(weight, strict=False)
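
    # Cast to the requested precision. fp16 goes through DeepSpeed kernel
    # injection; bf16 and fp32 are plain casts onto the GPU.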
if args.precision == "bf16":
model = model.bfloat16().cuda()
elif args.precision == "fp16":
import deepspeed
model_engine = deepspeed.init_inference(
model=model,
dtype=torch.half,
replace_with_kernel_inject=True,
replace_method="auto",
)
model = model_engine.module
else:
model = model.float().cuda()
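
    # LLaVA-style image tokens: each image placeholder in the prompt is
    # expanded into 256 patch tokens wrapped in <im_start>/<im_end> markers.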
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
image_token_len = 256
clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower)
transform = ResizeLongestSide(args.image_size)
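
    # Interactive loop: build a vicuna-style conversation, splice in the image
    # tokens, preprocess the image for both CLIP and SAM, and run inference.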
while True:
conv = get_default_conv_template("vicuna").copy()
conv.messages = []
prompt = input("Please input your prompt: ")
prompt = DEFAULT_IMAGE_TOKEN + " " + prompt
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], "")
prompt = conv.get_prompt()
image_path = input("Please input the image path: ")
if not os.path.exists(image_path):
print("File not found in {}".format(image_path))
continue
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_size_list = [image.shape[:2]]
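
        # CLIP stream: `images_clip` feeds the vision tower; the SAM stream
        # for the mask decoder is prepared further below.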
if args.precision == "bf16":
images_clip = (
clip_image_processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
.unsqueeze(0)
.cuda()
.bfloat16()
)
elif args.precision == "fp16":
images_clip = (
clip_image_processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
.unsqueeze(0)
.cuda()
.half()
)
else:
images_clip = (
clip_image_processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
.unsqueeze(0)
.cuda()
.float()
)
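
        # SAM stream: resize the longest side to --image-size, remember the
        # resized shape for un-padding, then normalize and pad to a square.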
images = transform.apply_image(image)
resize_list = [images.shape[:2]]
if args.precision == "bf16":
images = (
preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
.unsqueeze(0)
.cuda()
.bfloat16()
)
elif args.precision == "fp16":
images = (
preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
.unsqueeze(0)
.cuda()
.half()
)
else:
images = (
preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
.unsqueeze(0)
.cuda()
.float()
)
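
        # Tokenize the assembled conversation and run generation; `evaluate`
        # returns the generated token ids and the predicted masks.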
input_ids = tokenizer(prompt).input_ids
input_ids = torch.LongTensor(input_ids).unsqueeze(0).cuda()
output_ids, pred_masks = model.evaluate(
images_clip,
images,
input_ids,
resize_list,
original_size_list,
max_new_tokens=512,
tokenizer=tokenizer,
)
text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
text_output = (
text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "")
.replace("\n", "")
.replace(" ", "")
)
print("text_output: ", text_output)
for i, pred_mask in enumerate(pred_masks):
if pred_mask.shape[0] == 0:
continue
pred_mask = pred_mask.detach().cpu().numpy()[0]
pred_mask = pred_mask > 0
save_path = "{}/{}_mask_{}.jpg".format(
args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
)
            # Cast the boolean mask to uint8 (0 or 100) so cv2.imwrite accepts it.
            cv2.imwrite(save_path, (pred_mask * 100).astype(np.uint8))
print("{} has been saved.".format(save_path))
save_path = "{}/{}_masked_img_{}.jpg".format(
args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
)
save_img = image.copy()
save_img[pred_mask] = (
image * 0.5
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
)[pred_mask]
save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(save_path, save_img)
print("{} has been saved.".format(save_path))
if __name__ == "__main__":
main(sys.argv[1:])