import torch
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from conversation import conv_templates, SeparatorStyle
from builder import load_pretrained_model
from utils import disable_torch_init
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
import spaces
from functools import partial
import sys


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert it to RGB."""
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    if image is None:
        sys.exit("Aborting program: image is None.")
    return image


@spaces.GPU()
def run_inference(
    model_path,
    image_file,
    prompt_text,
    model_base=None,
    device="cuda",
    conv_mode=None,
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
):
    # Model initialization
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, load_8bit, load_4bit
    )

    # Determine the conversation mode from the model name
    if "llama-2" in model_name.lower():
        conv_mode_inferred = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode_inferred = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode_inferred = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode_inferred = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode_inferred = "mpt"
    elif "gemma" in model_name.lower():
        conv_mode_inferred = "ferret_gemma_instruct"
    elif "llama" in model_name.lower():
        conv_mode_inferred = "ferret_llama_3"
    else:
        conv_mode_inferred = "llava_v0"

    # Use the user-specified conversation mode if provided
    conv_mode = conv_mode or conv_mode_inferred
    if conv_mode != conv_mode_inferred:
        print(f'[WARNING] the auto-inferred conversation mode is {conv_mode_inferred}, while `conv_mode` is {conv_mode}; using {conv_mode}')

    conv = conv_templates[conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    # Load and preprocess the image
    print("loading image:", image_file)
    image = load_image(image_file)
    image_size = image.size
    image_h = 336  # Target height for image preprocessing
    image_w = 336  # Target width for image preprocessing

    if model.config.image_aspect_ratio == "square_nocrop":
        image_tensor = image_processor.preprocess(
            image, return_tensors='pt', do_resize=True, do_center_crop=False,
            size=[image_h, image_w]
        )['pixel_values'][0]
    elif model.config.image_aspect_ratio == "anyres":
        image_process_func = partial(
            image_processor.preprocess, return_tensors='pt', do_resize=True,
            do_center_crop=False, size=[image_h, image_w]
        )
        image_tensor = process_images(
            [image], image_processor, model.config, image_process_func=image_process_func
        )[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]

    # Match the image tensor dtype to the model dtype
    if model.dtype == torch.float16:
        image_tensor = image_tensor.half()  # Convert image tensor to float16
        data_type = torch.float16
    else:
        image_tensor = image_tensor.float()  # Keep it in float32
        data_type = torch.float32

    # Add the batch dimension and move to the GPU
    images = image_tensor.unsqueeze(0).to(data_type).cuda()

    # Prepend the image token(s) to the first message
    if model.config.mm_use_im_start_end:
        prompt_text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt_text
    else:
        prompt_text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text

    # Prepare the conversation prompt
    conv.append_message(conv.roles[0], prompt_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    print("image size:", image_size)

    # Generate the model's response
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            num_beams=1,
            use_cache=True,
        )

    # Decode and return the output
    outputs = tokenizer.decode(output_ids[0]).strip()
    conv.messages[-1][-1] = outputs

    if debug:
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

    return outputs


# Example usage:
# response = run_inference("path_to_model", "path_to_image", "your_prompt")
# print(response)
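

# A minimal runnable sketch of the example usage above. The checkpoint path, image path,
# and prompt are placeholders (assumptions, not artifacts shipped with this repo);
# substitute your own model directory or Hugging Face repo id and a real image file.
if __name__ == "__main__":
    response = run_inference(
        model_path="path/to/checkpoint",    # hypothetical model directory
        image_file="path/to/example.jpg",   # hypothetical local image (http(s) URLs also work)
        prompt_text="Describe this image.",
        temperature=0.2,
        max_new_tokens=256,
        debug=True,
    )
    print(response)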