# ferret-demo / cli.py
# Author: jadechoghari — commit 151137d ("working app"), 5.87 kB
import torch
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from conversation import conv_templates, SeparatorStyle
from builder import load_pretrained_model
from utils import disable_torch_init
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
import spaces
from functools import partial
import traceback
import sys
# def load_image(image_file):
# if image_file.startswith('http://') or image_file.startswith('https://'):
# response = requests.get(image_file)
# image = Image.open(BytesIO(response.content)).convert('RGB')
# else:
# image = Image.open(image_file).convert('RGB')
# return image
def load_image(image_file):
    """Open *image_file* and return it as an RGB ``PIL.Image.Image``.

    Args:
        image_file: a local filesystem path, an ``http://``/``https://`` URL
            string, a binary file-like object, or an already-open
            ``PIL.Image.Image``.

    Returns:
        A new RGB image whose pixels are fully loaded into memory; any file
        opened here is closed before returning.

    Raises:
        Whatever ``PIL.Image.open`` raises for unreadable input (for example
        ``FileNotFoundError`` or ``PIL.UnidentifiedImageError``), and
        ``requests.HTTPError`` / ``requests.Timeout`` for a failed download.
    """
    print("the image file : ", image_file)
    if isinstance(image_file, Image.Image):
        return image_file.convert('RGB')
    if isinstance(image_file, str) and image_file.startswith(('http://', 'https://')):
        # A timeout stops an unresponsive host from hanging the worker forever.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        image_file = BytesIO(response.content)
    # Image.open raises on failure (it never returns None), so no None check
    # is needed. The context manager closes the underlying file once
    # convert() has produced an independent in-memory copy.
    with Image.open(image_file) as image:
        return image.convert('RGB')
# Maps a model-name substring to a conversation template. Rules are tried in
# order and the first match wins, so more specific names must come before
# broader ones ("llama-2" before "llama", "v1.6-34b" before "v1").
_CONV_MODE_RULES = (
    ("llama-2", "llava_llama_2"),
    ("mistral", "mistral_instruct"),
    ("v1.6-34b", "chatml_direct"),
    ("v1", "llava_v1"),
    ("mpt", "mpt"),
    ("gemma", "ferret_gemma_instruct"),
    ("llama", "ferret_llama_3"),
)


def _infer_conv_mode(model_name):
    """Return the conversation-template key implied by *model_name*.

    The name is matched case-insensitively against ``_CONV_MODE_RULES``.
    Returns ``"llava_v0"`` when no rule matches.
    """
    lowered = model_name.lower()
    for needle, mode in _CONV_MODE_RULES:
        if needle in lowered:
            return mode
    return "llava_v0"


def _image_to_tensor(image, image_processor, model, image_h=336, image_w=336):
    """Convert a PIL image into a single unbatched pixel tensor for *model*.

    The preprocessing path depends on ``model.config.image_aspect_ratio``:

    - ``"square_nocrop"``: resize to ``image_h x image_w`` with no centre crop.
    - ``"anyres"``: pass that same resize-only preprocessor to
      ``process_images``.
    - Any other value: use ``process_images`` with its defaults.

    Returns:
        The tensor in float16 if ``model.dtype`` is float16, otherwise in
        float32. It stays on whatever device the preprocessor produced it on.
    """
    aspect_ratio = model.config.image_aspect_ratio
    if aspect_ratio == "square_nocrop":
        image_tensor = image_processor.preprocess(
            image, return_tensors='pt', do_resize=True,
            do_center_crop=False, size=[image_h, image_w],
        )['pixel_values'][0]
    elif aspect_ratio == "anyres":
        image_process_func = partial(
            image_processor.preprocess, return_tensors='pt', do_resize=True,
            do_center_crop=False, size=[image_h, image_w],
        )
        image_tensor = process_images(
            [image], image_processor, model.config,
            image_process_func=image_process_func,
        )[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]

    dtype = torch.float16 if model.dtype == torch.float16 else torch.float32
    return image_tensor.to(dtype)


@spaces.GPU()
def run_inference(
    model_path,
    image_file,
    prompt_text,
    model_base=None,
    device="cuda",
    conv_mode=None,
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False
):
    """Run one image plus text prompt through a Ferret/LLaVA-style model.

    The model is reloaded via ``load_pretrained_model`` on every call. The
    generated text is streamed to stdout as it is produced and also returned.

    Args:
        model_path: checkpoint passed to ``load_pretrained_model``.
        image_file: any image source accepted by ``load_image``.
        prompt_text: the user prompt. The image placeholder token(s) are
            prepended to it.
        model_base: optional base checkpoint passed to ``load_pretrained_model``.
        device: currently unused. Image and token tensors follow
            ``model.device``.
        conv_mode: conversation-template key. When falsy, it is inferred from
            the model name.
        temperature: sampling temperature. Values ``<= 0`` select greedy
            decoding.
        max_new_tokens: upper bound on generated tokens.
        load_8bit, load_4bit: quantization flags passed to
            ``load_pretrained_model``.
        debug: when True, print the final prompt and output.

    Returns:
        The decoded response, with special tokens and a trailing conversation
        separator removed.
    """
    # Model initialization
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, load_8bit, load_4bit
    )

    # An explicit conv_mode overrides the one inferred from the model name.
    conv_mode_inferred = _infer_conv_mode(model_name)
    conv_mode = conv_mode or conv_mode_inferred
    if conv_mode != conv_mode_inferred:
        print(f'[WARNING] the auto inferred conversation mode is {conv_mode_inferred}, while `conv_mode` is {conv_mode}, using {conv_mode}')
    conv = conv_templates[conv_mode].copy()

    # Load and preprocess the image. load_image raises on failure rather than
    # returning None, so no None check is needed here.
    print("loading image", image_file)
    image = load_image(image_file)
    image_size = image.size
    # Add the batch dimension. Use the same device as input_ids below instead
    # of a hard-coded .cuda().
    images = _image_to_tensor(image, image_processor, model).unsqueeze(0).to(model.device)

    # The first (and only) user turn carries the image placeholder.
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    prompt_text = image_token + '\n' + prompt_text

    conv.append_message(conv.roles[0], prompt_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    print("image size: ", image_size)

    # Generate the model's response.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            num_beams=1,
            use_cache=True,
        )

    # Decode the same way the streamer does (without special tokens such as
    # EOS), then drop a trailing conversation separator if the model emitted one.
    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)].rstrip()
    if debug:
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
    return outputs
# Example usage:
# response = run_inference("path_to_model", "path_to_image", "your_prompt")
# print(response)