"""Gradio demo that describes an uploaded image with a 4-bit LLaVA-v1.5-13B model."""

import types

import gradio as gr
import torch
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from PIL import Image

# Disable PyTorch's default parameter initialization before building the model
# (presumably to speed up construction, since the checkpoint weights are loaded
# right afterwards).
disable_torch_init()

# Load the pretrained model.
MODEL = "4bit/llava-v1.5-13b-3GB"
# Chat template for the prompt. LLaVA's reference script (llava/eval/run_llava.py)
# selects "llava_v1" for model names containing "v1", which includes this
# LLaVA-v1.5 checkpoint.
CONV_MODE = "llava_v1"

model_name = get_model_name_from_path(MODEL)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
)

# process_images() looks up the aspect-ratio policy with getattr(), so it needs
# an attribute-style config object. A plain dict would silently disable padding.
_IMAGE_CFG = types.SimpleNamespace(image_aspect_ratio="pad")


def create_prompt(prompt: str, conv_mode: str = CONV_MODE):
    """Wrap *prompt* in the chat template, preceded by the image placeholder token.

    Args:
        prompt: The user instruction, e.g. "Describe the image".
        conv_mode: Key into ``conv_templates`` selecting the chat format.

    Returns:
        A ``(prompt_text, conversation)`` tuple. The conversation is returned so
        callers can read its separators to build a stopping criterion.
    """
    conv = conv_templates[conv_mode].copy()
    roles = conv.roles
    conv.append_message(roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    # Empty assistant turn: the rendered prompt ends where the model should answer.
    conv.append_message(roles[1], None)
    return conv.get_prompt(), conv


def process_image(image):
    """Preprocess a PIL image into a float16 tensor on the model's device."""
    # Convert to RGB so images with alpha or palette modes (PNG, GIF) are handled.
    image_tensor = process_images([image.convert("RGB")], image_processor, _IMAGE_CFG)
    return image_tensor.to(model.device, dtype=torch.float16)


def describe_image(
    image_file, prompt: str = "Describe the image", max_new_tokens: int = 512
) -> str:
    """Generate a text description of an image with the loaded LLaVA model.

    Args:
        image_file: A ``PIL.Image.Image`` (what the Gradio ``type="pil"`` input
            delivers) or a path / file object accepted by ``Image.open``.
            ``None``, e.g. after the user clears the input in live mode,
            yields an empty string.
        prompt: Instruction sent to the model along with the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The model's answer, with surrounding whitespace and any trailing stop
        keyword removed.
    """
    if image_file is None:
        return ""
    if isinstance(image_file, Image.Image):
        image = image_file
    else:
        image = Image.open(image_file)
    processed_image = process_image(image)

    prompt_text, conv = create_prompt(prompt)
    input_ids = (
        tokenizer_image_token(
            prompt_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        .unsqueeze(0)
        .to(model.device)
    )
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=processed_image,
            do_sample=True,
            temperature=0.01,  # near-greedy, for a stable description
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    # Depending on the LLaVA version, generate() returns either prompt + new
    # tokens or only the new tokens. Strip the prompt only when it is present,
    # so the start of the answer is never cut off.
    if generated.shape[0] >= prompt_len and torch.equal(
        generated[:prompt_len], input_ids[0]
    ):
        generated = generated[prompt_len:]

    description = tokenizer.decode(generated, skip_special_tokens=True).strip()
    # A stop keyword that is not a special token may remain at the end of the text.
    if stop_str and description.endswith(stop_str):
        description = description[: -len(stop_str)].strip()
    return description


iface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.Textbox(),
    live=True,
)

# Launch the Gradio interface only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()