import argparse
import logging
import os
import re
import sys
from typing import Callable

import cv2
import gradio as gr
import nh3
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils import constants, session_logger, utils

session_logger.change_logging(logging.DEBUG)

CUSTOM_GRADIO_PATH = "/"
app = FastAPI(title="lisa_app", version="1.0")

# fall back to ./static so os.makedirs() does not fail if the env var is unset
FASTAPI_STATIC = os.getenv("FASTAPI_STATIC", "static")
os.makedirs(FASTAPI_STATIC, exist_ok=True)
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
templates = Jinja2Templates(directory="templates")

placeholders = utils.create_placeholder_variables()


@app.get("/health")
@session_logger.set_uuid_logging
def health() -> str:
    import json

    try:
        logging.info("health check")
        return json.dumps({"msg": "ok"})
    except Exception as e:
        logging.error(f"exception:{e}.")
        return json.dumps({"msg": "request failed"})


@session_logger.set_uuid_logging
def parse_args(args_to_parse):
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1-explanatory")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)


@session_logger.set_uuid_logging
def get_cleaned_input(input_str):
    logging.info(f"start cleaning of input_str: {input_str}.")
    input_str = nh3.clean(
        input_str,
        tags={
            "a",
            "abbr",
            "acronym",
            "b",
            "blockquote",
            "code",
            "em",
            "i",
            "li",
            "ol",
            "strong",
            "ul",
        },
        attributes={
            "a": {"href", "title"},
            "abbr": {"title"},
            "acronym": {"title"},
        },
        url_schemes={"http", "https", "mailto"},
        link_rel=None,
    )
    logging.info(f"cleaned input_str: {input_str}.")
    return input_str


@session_logger.set_uuid_logging
def set_image_precision_by_args(input_image, precision):
    if precision == "bf16":
        input_image = input_image.bfloat16()
    elif precision == "fp16":
        input_image = input_image.half()
    else:
        input_image = input_image.float()
    return input_image


@session_logger.set_uuid_logging
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    logging.info("preprocess started")
    # Normalize colors
    x = (x - pixel_mean) / pixel_std
    # Pad to a square of img_size x img_size
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    logging.info("preprocess ended")
    return x
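
# Illustrative sketch (hypothetical helper, never called by the app):
# `preprocess` mirrors SAM's normalize-and-pad step, so any CxHxW tensor with
# both sides at most img_size comes back as a square img_size x img_size
# tensor, zero-padded on the bottom and right.
def _preprocess_shape_example() -> None:
    dummy = torch.zeros(3, 768, 512)  # C x H x W, both sides below 1024
    padded = preprocess(dummy)
    assert padded.shape == (3, 1024, 1024)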
@session_logger.set_uuid_logging
def get_model(args_to_parse):
    logging.info("starting model preparation...")
    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)

    # Create model
    _tokenizer = AutoTokenizer.from_pretrained(
        args_to_parse.version,
        cache_dir=None,
        model_max_length=args_to_parse.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    _tokenizer.pad_token = _tokenizer.unk_token
    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    torch_dtype = torch.float32
    if args_to_parse.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args_to_parse.precision == "fp16":
        torch_dtype = torch.half
    kwargs = {"torch_dtype": torch_dtype}
    if args_to_parse.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args_to_parse.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )
    _model = LISAForCausalLM.from_pretrained(
        args_to_parse.version,
        low_cpu_mem_usage=True,
        vision_tower=args_to_parse.vision_tower,
        seg_token_idx=args_to_parse.seg_token_idx,
        **kwargs,
    )
    _model.config.eos_token_id = _tokenizer.eos_token_id
    _model.config.bos_token_id = _tokenizer.bos_token_id
    _model.config.pad_token_id = _tokenizer.pad_token_id
    _model.get_model().initialize_vision_modules(_model.get_model().config)
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    if args_to_parse.precision == "bf16":
        _model = _model.bfloat16().cuda()
    elif (
        args_to_parse.precision == "fp16"
        and (not args_to_parse.load_in_4bit)
        and (not args_to_parse.load_in_8bit)
    ):
        # detach the vision tower while deepspeed wraps the language model
        vision_tower = _model.get_model().get_vision_tower()
        _model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=_model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        _model = model_engine.module
        _model.model.vision_tower = vision_tower.half().cuda()
    elif args_to_parse.precision == "fp32":
        _model = _model.float().cuda()
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(device=args_to_parse.local_rank)
    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
    _transform = ResizeLongestSide(args_to_parse.image_size)
    _model.eval()
    logging.info("model preparation ok!")
    return _model, _clip_image_processor, _tokenizer, _transform
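
# Minimal sketch (hypothetical helper, unused by the app): get_model maps the
# --precision flag to a torch dtype as below, then overrides it to fp16
# whenever 4-bit or 8-bit loading is requested, keeping only the SAM visual
# tower ("visual_model") out of quantization via llm_int8_skip_modules.
def _resolve_torch_dtype(precision: str) -> torch.dtype:
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32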
@session_logger.set_uuid_logging
def get_inference_model_by_args(args_to_parse):
    logging.info(f"args_to_parse:{args_to_parse}, creating model...")
    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
    logging.info("created model, preparing inference function")
    no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]

    @session_logger.set_uuid_logging
    def inference(input_str, input_image):
        ## filter out special chars
        input_str = get_cleaned_input(input_str)
        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
        logging.info(f"input_str: {input_str}.")

        ## input validity check
        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
            output_str = f"[Error] Invalid input: {input_str}"
            return error_happened, output_str

        # Model inference
        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
        conv.messages = []
        prompt = input_str
        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
        if args_to_parse.use_mm_start_end:
            replace_token = (
                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_np = cv2.imread(input_image)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image_clip type: {type(image_clip)}.")
        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image type: {type(image)}.")
        image = set_image_precision_by_args(image, args_to_parse.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        text_output = text_output.split("ASSISTANT: ")[-1]
        logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.")

        save_img = None
        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0
            # overlay the predicted mask in red on a copy of the input image
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
        output_str = f"ASSISTANT: {text_output}"
        output_image = no_seg_out if save_img is None else save_img
        logging.info(f"output_image type: {type(output_image)}.")
        return output_image, output_str

    logging.info("prepared inference function!")
    return inference


@session_logger.set_uuid_logging
def get_gradio_interface(fn_inference: Callable):
    return gr.Interface(
        fn_inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image"),
        ],
        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
            gr.Textbox(lines=1, placeholder=None, label="Text Output"),
        ],
        title=constants.title,
        description=constants.description,
        article=constants.article,
        examples=constants.examples,
        allow_flagging="auto",
    )


logging.info(f"sys.argv:{sys.argv}.")
args = parse_args([])
logging.info(f"prepared default arguments:{args}.")
inference_fn = get_inference_model_by_args(args)
logging.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
io = get_gradio_interface(inference_fn)
logging.info("created gradio interface")
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
logging.info("mounted gradio app within fastapi")
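
# Local run sketch (assumptions: this module is saved as app.py; the filename,
# port, and FASTAPI_STATIC path are illustrative): because the Gradio UI is
# mounted at "/" on the FastAPI app, a single uvicorn process serves both the
# chat interface and the /health endpoint, e.g.:
#
#     FASTAPI_STATIC=./static uvicorn app:app --host 0.0.0.0 --port 7860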