import argparse
import logging
import os
import re
import sys
from typing import Callable

import cv2
import gradio as gr
import nh3
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils import constants, session_logger, utils

session_logger.change_logging(logging.DEBUG)
CUSTOM_GRADIO_PATH = "/"
app = FastAPI(title="lisa_app", version="1.0")

# FASTAPI_STATIC must be set in the environment: os.makedirs() raises a TypeError if it is None.
FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
os.makedirs(FASTAPI_STATIC, exist_ok=True)
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
templates = Jinja2Templates(directory="templates")
placeholders = utils.create_placeholder_variables()

def health() -> str:
    import json

    try:
        logging.info("health check")
        return json.dumps({"msg": "ok"})
    except Exception as e:
        logging.error(f"exception:{e}.")
        return json.dumps({"msg": "request failed"})
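
# health() is defined above but no route registration for it appears in this file; assuming it is
# meant to be a liveness endpoint, one way to expose it on the existing FastAPI app (the "/health"
# path is an assumption) is:
app.add_api_route("/health", health, methods=["GET"])
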
def parse_args(args_to_parse):
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1-explanatory")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    # NOTE: with action="store_true" and default=True the two flags below can never be switched
    # off from the command line; use default=False or argparse.BooleanOptionalAction to make
    # them togglable.
    parser.add_argument("--load_in_4bit", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)

def get_cleaned_input(input_str):
    logging.info(f"start cleaning of input_str: {input_str}.")
    input_str = nh3.clean(
        input_str,
        tags={
            "a",
            "abbr",
            "acronym",
            "b",
            "blockquote",
            "code",
            "em",
            "i",
            "li",
            "ol",
            "strong",
            "ul",
        },
        attributes={
            "a": {"href", "title"},
            "abbr": {"title"},
            "acronym": {"title"},
        },
        url_schemes={"http", "https", "mailto"},
        link_rel=None,
    )
    logging.info(f"cleaned input_str: {input_str}.")
    return input_str
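
# Illustration of the sanitization above (kept as a comment so it is not executed at import time):
# markup outside the allow-list is stripped while allow-listed tags/attributes survive, e.g.
#   get_cleaned_input('<b onclick="x()">hi</b><script>alert(1)</script>')
# should come back as roughly '<b>hi</b>' once nh3 drops the script element and the
# non-allow-listed onclick attribute.
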
def set_image_precision_by_args(input_image, precision):
    if precision == "bf16":
        input_image = input_image.bfloat16()
    elif precision == "fp16":
        input_image = input_image.half()
    else:
        input_image = input_image.float()
    return input_image
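
# e.g. set_image_precision_by_args(torch.zeros(1, 3, 4, 4), "fp16").dtype -> torch.float16,
# while any precision other than "bf16"/"fp16" falls back to torch.float32.
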
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    logging.info("preprocess started")
    # Normalize colors
    x = (x - pixel_mean) / pixel_std
    # Pad
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    logging.info("preprocess ended")
    return x
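
# Rough sanity check for preprocess() (as a comment so nothing runs at import time): the input is
# a CHW float tensor whose longest side has already been resized to img_size by ResizeLongestSide,
# and the output is zero-padded to a square, e.g.
#   preprocess(torch.zeros(3, 768, 1024)).shape == torch.Size([3, 1024, 1024])
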
def get_model(args_to_parse):
    logging.info("starting model preparation...")
    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)

    # global tokenizer, tokenizer
    # Create model
    _tokenizer = AutoTokenizer.from_pretrained(
        args_to_parse.version,
        cache_dir=None,
        model_max_length=args_to_parse.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    _tokenizer.pad_token = _tokenizer.unk_token
    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    torch_dtype = torch.float32
    if args_to_parse.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args_to_parse.precision == "fp16":
        torch_dtype = torch.half
    kwargs = {"torch_dtype": torch_dtype}
    if args_to_parse.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args_to_parse.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )
    _model = LISAForCausalLM.from_pretrained(
        args_to_parse.version,
        low_cpu_mem_usage=True,
        vision_tower=args_to_parse.vision_tower,
        seg_token_idx=args_to_parse.seg_token_idx,
        **kwargs,
    )
    _model.config.eos_token_id = _tokenizer.eos_token_id
    _model.config.bos_token_id = _tokenizer.bos_token_id
    _model.config.pad_token_id = _tokenizer.pad_token_id
    _model.get_model().initialize_vision_modules(_model.get_model().config)
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    if args_to_parse.precision == "bf16":
        _model = _model.bfloat16().cuda()
    elif (
        args_to_parse.precision == "fp16"
        and (not args_to_parse.load_in_4bit)
        and (not args_to_parse.load_in_8bit)
    ):
        vision_tower = _model.get_model().get_vision_tower()
        _model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=_model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        _model = model_engine.module
        _model.model.vision_tower = vision_tower.half().cuda()
    elif args_to_parse.precision == "fp32":
        _model = _model.float().cuda()
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(device=args_to_parse.local_rank)
    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
    _transform = ResizeLongestSide(args_to_parse.image_size)
    _model.eval()
    logging.info("model preparation ok!")
    return _model, _clip_image_processor, _tokenizer, _transform

def get_inference_model_by_args(args_to_parse):
    logging.info(f"args_to_parse:{args_to_parse}, creating model...")
    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
    logging.info("created model, preparing inference function")
    no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]

    def inference(input_str, input_image):
        ## filter out special chars
        input_str = get_cleaned_input(input_str)
        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
        logging.info(f"input_str: {input_str}.")

        ## input valid check
        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
            output_str = f"[Error] Invalid input: {input_str}"
            return error_happened, output_str
        # Model Inference
        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
        conv.messages = []

        prompt = input_str
        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
        if args_to_parse.use_mm_start_end:
            replace_token = (
                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_np = cv2.imread(input_image)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image_clip type: {type(image_clip)}.")
        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
logging.info(f"image_clip type: {type(image_clip)}.") | |
image = set_image_precision_by_args(image, args_to_parse.precision) | |
        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        text_output = text_output.split("ASSISTANT: ")[-1]
logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.") | |
save_img = None | |
for i, pred_mask in enumerate(pred_masks): | |
if pred_mask.shape[0] == 0: | |
continue | |
pred_mask = pred_mask.detach().cpu().numpy()[0] | |
pred_mask = pred_mask > 0 | |
save_img = image_np.copy() | |
save_img[pred_mask] = ( | |
image_np * 0.5 | |
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5 | |
)[pred_mask] | |
output_str = f"ASSISTANT: {text_output}" | |
output_image = no_seg_out if save_img is None else save_img | |
logging.info(f"output_image type: {type(output_image)}.") | |
return output_image, output_str | |
logging.info("prepared inference function!") | |
return inference | |
def get_gradio_interface(
    fn_inference: Callable
):
    return gr.Interface(
        fn_inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image")
        ],
        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
            gr.Textbox(lines=1, placeholder=None, label="Text Output")
        ],
        title=constants.title,
        description=constants.description,
        article=constants.article,
        examples=constants.examples,
        allow_flagging="auto"
    )

logging.info(f"sys.argv:{sys.argv}.") | |
args = parse_args([]) | |
logging.info(f"prepared default arguments:{args}.") | |
inference_fn = get_inference_model_by_args(args) | |
logging.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...") | |
io = get_gradio_interface(inference_fn) | |
logging.info("created gradio interface") | |
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH) | |
logging.info("mounted gradio app within fastapi") | |