import torch
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from conversation import conv_templates, SeparatorStyle
from builder import load_pretrained_model
from utils import disable_torch_init
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
import spaces  # `spaces` package for Hugging Face ZeroGPU Spaces
from functools import partial
import traceback
import sys
# def load_image(image_file):
#     if image_file.startswith('http://') or image_file.startswith('https://'):
#         response = requests.get(image_file)
#         image = Image.open(BytesIO(response.content)).convert('RGB')
#     else:
#         image = Image.open(image_file).convert('RGB')
#     return image

def load_image(image_file):
    print("the image file : ", image_file)
    # PIL raises an exception rather than returning None when the file cannot be
    # opened, so catch the error and abort explicitly instead of checking for None.
    try:
        image = Image.open(image_file).convert('RGB')
    except OSError:
        traceback.print_exc()
        sys.exit("Aborting program: image could not be loaded.")
    return image

def run_inference(
    model_path,
    image_file,
    prompt_text,
    model_base=None,
    device="cuda",
    conv_mode=None,
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
):
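    """Run one round of image+text inference with a LLaVA/Ferret-style checkpoint.

    Loads the model at `model_path`, preprocesses `image_file` according to the
    model's `image_aspect_ratio` config, wraps `prompt_text` in the appropriate
    conversation template, and returns the decoded response string.
    """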
    # Model initialization
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, load_8bit, load_4bit
    )
    # Determine the conversation mode from the model name
    if "llama-2" in model_name.lower():
        conv_mode_inferred = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode_inferred = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode_inferred = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode_inferred = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode_inferred = "mpt"
    elif "gemma" in model_name.lower():
        conv_mode_inferred = "ferret_gemma_instruct"
    elif "llama" in model_name.lower():
        conv_mode_inferred = "ferret_llama_3"
    else:
        conv_mode_inferred = "llava_v0"

    # Use the user-specified conversation mode if provided
    conv_mode = conv_mode or conv_mode_inferred
    if conv_mode != conv_mode_inferred:
        print(f'[WARNING] the auto inferred conversation mode is {conv_mode_inferred}, while `conv_mode` is {conv_mode}, using {conv_mode}')
    conv = conv_templates[conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles
    # Load and process the image; load_image() aborts the program if the file cannot be opened
    print("loading image", image_file)
    image = load_image(image_file)
    image_size = image.size
    image_h = 336  # target height used by the image processor
    image_w = 336  # target width used by the image processor
    # ERROR: earlier generic preprocessing path, kept commented out for reference
    # image_tensor = process_images([image], image_processor, model.config)
    # if type(image_tensor) is list:
    #     image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    # else:
    #     image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    if model.config.image_aspect_ratio == "square_nocrop":
        image_tensor = image_processor.preprocess(
            image, return_tensors='pt', do_resize=True,
            do_center_crop=False, size=[image_h, image_w]
        )['pixel_values'][0]
    elif model.config.image_aspect_ratio == "anyres":
        image_process_func = partial(
            image_processor.preprocess, return_tensors='pt',
            do_resize=True, do_center_crop=False, size=[image_h, image_w]
        )
        image_tensor = process_images([image], image_processor, model.config, image_process_func=image_process_func)[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]

    if model.dtype == torch.float16:
        image_tensor = image_tensor.half()  # convert image tensor to float16
        data_type = torch.float16
    else:
        image_tensor = image_tensor.float()  # keep it in float32
        data_type = torch.float32

    # Now, add the batch dimension and move to the GPU
    images = image_tensor.unsqueeze(0).to(data_type).cuda()
    # Process the first message with the image token
    if model.config.mm_use_im_start_end:
        prompt_text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt_text
    else:
        prompt_text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text

    # Prepare the conversation prompt
    conv.append_message(conv.roles[0], prompt_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
print("image size: ", image_size) | |
# Generate the model's response | |
with torch.inference_mode(): | |
output_ids = model.generate( | |
input_ids, | |
images=images, | |
image_sizes=[image_size], | |
do_sample=True if temperature > 0 else False, | |
temperature=temperature, | |
max_new_tokens=max_new_tokens, | |
streamer=streamer, | |
num_beams=1, | |
use_cache=True | |
) | |
    # Decode and return the output
    outputs = tokenizer.decode(output_ids[0]).strip()
    conv.messages[-1][-1] = outputs

    if debug:
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

    return outputs

# Example usage:
# response = run_inference("path_to_model", "path_to_image", "your_prompt")
# print(response)
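
# Minimal runnable sketch of the usage above, assuming `path_to_model` and
# `path_to_image.jpg` are placeholders for a real checkpoint directory and a
# real local image file on a CUDA-capable machine.
if __name__ == "__main__":
    response = run_inference(
        model_path="path_to_model",        # placeholder: directory of the pretrained checkpoint
        image_file="path_to_image.jpg",    # placeholder: local image to describe
        prompt_text="Describe this image.",
        temperature=0.2,
        max_new_tokens=512,
        debug=True,
    )
    print(response)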