"""VLM-Image-Analysis: run several vision-language models over a batch of images.

Images come either from a Hugging Face dataset or from an uploaded ZIP file.
Every selected model answers every selected prompt for every image, and the
results are rendered as an HTML table in a Gradio UI.
"""
import base64
import functools
import html
import io
import os
import sqlite3
import traceback
import zipfile

import gradio as gr
import pandas as pd
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
)

# Define constants for vikhyatk/moondream2 model
MOON_DREAM_MODEL_ID = "vikhyatk/moondream2"
MOON_DREAM_REVISION = "2024-08-26"

# Models routed through the Qwen2-VL loading/generation path.
# NOTE(review): 'OpenGVLab/InternVL2-1B' is not a Qwen2-VL architecture; loading it
# with Qwen2VLForConditionalGeneration is expected to fail (the error is printed and
# the model skipped) -- TODO confirm and give it a dedicated loader.
QWEN2_VL_MODELS = [
    'Qwen/Qwen2-VL-7B-Instruct',
    'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4',
    'OpenGVLab/InternVL2-1B',
    'Qwen/Qwen2-VL-72B',
]

# All models offered in the UI (the Qwen2-VL list plus the two custom-code models).
available_models = [
    *QWEN2_VL_MODELS,
    'microsoft/Phi-3-vision-128k-instruct',
    'vikhyatk/moondream2',
]

# List of available Hugging Face datasets
dataset_options = [
    "gokaygokay/panorama_hdr_dataset",
    "OpenGVLab/CRPE",
]

# Dataset column holding the image, for datasets that do not use "image".
_DATASET_IMAGE_COLUMNS = {
    "gokaygokay/panorama_hdr_dataset": "png_image",
}

# File extensions accepted from ZIP archives and directories (matches the UI text).
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

# List of text prompts to use.  The exact text doubles as part of the result
# column names, so do not edit it casually.
text_prompts = [
    "Provide a detailed description of the image contents, including all visible objects, people, activities, and extract any text present within the image using Optical Character Recognition (OCR). Organize the extracted text in a structured table format with columns for original text, its translation into English, and the language it is written in.",
    "Offer a thorough description of all elements within the image, from objects to individuals and their activities. Ensure any legible text seen in the image is extracted using Optical Character Recognition (OCR). Provide an accurate narrative that encapsulates the full content of the image.",
    "Create a four-sentence caption for the image. Start by specifying the style and type, such as painting, photograph, or digital art. In the next sentences, detail the contents and the composition clearly and concisely. Use language suited for prompting a text-to-image model, separating descriptive terms with commas instead of 'or'. Keep the description direct, avoiding interpretive phrases or abstract expressions",
]

# SQLite setup (currently disabled)
# def init_db():
#     conn = sqlite3.connect('image_outputs.db')
#     cursor = conn.cursor()
#     cursor.execute('''
#         CREATE TABLE IF NOT EXISTS image_outputs (
#             id INTEGER PRIMARY KEY AUTOINCREMENT,
#             image BLOB,
#             prompt TEXT,
#             output TEXT,
#             model_name TEXT
#         )
#     ''')
#     conn.commit()
#     conn.close()


def image_to_binary(image_path):
    """Return the raw bytes of the file at *image_path*."""
    with open(image_path, 'rb') as file:
        return file.read()


# def store_in_db(image_path, prompt, output, model_name):
#     conn = sqlite3.connect('image_outputs.db')
#     cursor = conn.cursor()
#     image_blob = image_to_binary(image_path)
#     cursor.execute('''
#         INSERT INTO image_outputs (image, prompt, output, model_name)
#         VALUES (?, ?, ?, ?)
#     ''', (image_blob, prompt, output, model_name))
#     conn.commit()
#     conn.close()


def encode_image(image):
    """Encode a PIL image as an inline base64 PNG ``<img>`` tag for HTML display.

    Args:
        image: A PIL ``Image`` (anything with a ``save(buffer, format=...)`` method).

    Returns:
        An HTML ``<img>`` element whose ``src`` is a ``data:image/png;base64`` URI.
    """
    img_buffer = io.BytesIO()
    image.save(img_buffer, format="PNG")
    img_str = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_str}" style="max-width: 300px;"/>'


def load_dataset_images(dataset_name, num_images):
    """Render the first *num_images* rows of a dataset's train split as HTML.

    Returns:
        The concatenated HTML for each image and its width/height (and ``hdr``
        field when present), a notice string if no row has an ``image`` column,
        or an error string if the dataset cannot be loaded.
    """
    try:
        dataset = load_dataset(dataset_name, split='train')
        # Slicing a Dataset (dataset[:n]) returns a dict of columns, not rows;
        # select() keeps row-wise iteration.  Clamp to the dataset size.
        count = min(int(num_images), len(dataset))
        images = []
        for i, item in enumerate(dataset.select(range(count))):
            if 'image' not in item:
                continue
            img = item['image']
            encoded_img = encode_image(img)
            metadata = f"Width: {img.width}, Height: {img.height}"
            if 'hdr' in item:
                metadata += f", HDR: {item['hdr']}"
            images.append(
                f"<div><h3>Image {i + 1}</h3>{encoded_img}<p>{metadata}</p></div>"
            )
        if not images:
            return "No images could be loaded from this dataset. Please check the dataset structure."
        return "".join(images)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        return f"Error loading dataset: {e}"


@functools.lru_cache(maxsize=1)
def _get_moondream_tokenizer():
    """Load the moondream2 tokenizer once and reuse it for every generation call."""
    return AutoTokenizer.from_pretrained(MOON_DREAM_MODEL_ID, revision=MOON_DREAM_REVISION)


def generate_output(model, processor, prompt, image, model_name, device):
    """Ask one model one prompt about one image.

    Args:
        model: The loaded model for *model_name*.
        processor: Its processor (``None`` for moondream2).
        prompt: The text prompt.
        image: A PIL image.
        model_name: Hub id used to pick the model-specific generation path.
        device: Device string that the input tensors are moved to.

    Returns:
        The decoded response text.  If generation raises, or *model_name* is not
        supported, a human-readable error string is returned instead (never raises).
    """
    try:
        if model_name in QWEN2_VL_MODELS:
            png_buffer = io.BytesIO()
            image.save(png_buffer, format="PNG")
            image_bytes = png_buffer.getvalue()
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_bytes},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(
                text=[text],
                images=[Image.open(io.BytesIO(image_bytes))],
                padding=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            generated_ids = model.generate(**inputs, max_new_tokens=1024)
            # Drop the echoed prompt tokens so only the new completion is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
            ]
            return processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]

        if model_name == 'microsoft/Phi-3-vision-128k-instruct':
            messages = [{"role": "user", "content": f"<|image_1|>\n{prompt}"}]
            chat_prompt = processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = processor(chat_prompt, [image], return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            generate_ids = model.generate(
                **inputs,
                eos_token_id=processor.tokenizer.eos_token_id,
                max_new_tokens=1024,
            )
            # Keep only the tokens generated after the prompt.
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            return processor.batch_decode(generate_ids, skip_special_tokens=True)[0]

        if model_name == 'vikhyatk/moondream2':
            tokenizer = _get_moondream_tokenizer()
            enc_image = model.encode_image(image)
            return model.answer_question(enc_image, prompt, tokenizer)

        return f"Unsupported model: {model_name}"
    except Exception as e:
        return f"Error during generation with model {model_name}: {e}"


def list_images(directory_path):
    """Return ``{"filename", "image"}`` dicts (image as an HTML ``<img>`` tag) for
    every file with a supported image extension directly inside *directory_path*."""
    images = []
    for filename in os.listdir(directory_path):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue
        image_path = os.path.join(directory_path, filename)
        # encode_image needs a PIL image, not a path.
        with Image.open(image_path) as img:
            encoded_img = encode_image(img.convert("RGB"))
        images.append({"filename": filename, "image": encoded_img})
    return images


def extract_images_from_zip(zip_file):
    """Decode every supported image inside a ZIP archive.

    Args:
        zip_file: Path or binary file object accepted by ``zipfile.ZipFile``.

    Returns:
        A list of ``{"filename": str, "image": PIL.Image (RGB)}`` dicts.  Entries
        that fail to decode are reported and skipped.
    """
    images = []
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            name = file_info.filename
            # Skip non-images and macOS resource-fork entries (e.g. __MACOSX/._a.png).
            if not name.lower().endswith(IMAGE_EXTENSIONS) or name.startswith("__MACOSX/"):
                continue
            with zip_ref.open(file_info) as file:
                try:
                    # convert() forces the pixel data to load while the member is open.
                    img = Image.open(file).convert("RGB")
                    images.append({"filename": name, "image": img})
                except Exception as e:
                    print(f"Error opening image {name}: {e}")
    return images


def _resolve_device(device_map):
    """Map the UI 'Device Map' choice to a concrete device string.

    'auto' picks CUDA when available, otherwise CPU; explicit choices are kept.
    """
    if device_map == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_map


def _load_models(model_names, torch_dtype_value, device_map_value, trust_remote_code, use_flash_attn):
    """Load each requested model and processor; return ``{name: (model, processor)}``.

    Models that fail to load (or are unsupported) are reported and omitted.
    """
    model_processors = {}
    for model_name in model_names:
        try:
            if model_name in QWEN2_VL_MODELS:
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_name, torch_dtype=torch_dtype_value, device_map=device_map_value
                ).eval()
                processor = AutoProcessor.from_pretrained(model_name)
            elif model_name == 'microsoft/Phi-3-vision-128k-instruct':
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map=device_map_value,
                    torch_dtype=torch_dtype_value,
                    trust_remote_code=trust_remote_code,
                    # Phi-3-vision selects its attention kernel via _attn_implementation.
                    _attn_implementation="flash_attention_2" if use_flash_attn else "eager",
                ).eval()
                processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
            elif model_name == 'vikhyatk/moondream2':
                model = AutoModelForCausalLM.from_pretrained(
                    MOON_DREAM_MODEL_ID, trust_remote_code=True, revision=MOON_DREAM_REVISION
                ).eval()
                processor = None  # moondream2 takes a tokenizer at generation time instead
            else:
                # Without this, the previous iteration's model would be reused.
                print(f"Unsupported model: {model_name}")
                continue
            model_processors[model_name] = (model, processor)
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            traceback.print_exc()
    return model_processors


def _format_response(text):
    """HTML-escape model output (rendered with escape=False) and keep its line breaks."""
    return html.escape(str(text)).replace("\n", "<br>")


def _build_row(image, model_names, model_processors, prompts, device):
    """Build one result row: the encoded image plus one column per (model, prompt)."""
    row_data = {"Image": encode_image(image)}
    for model_name in model_names:
        if model_name not in model_processors:
            continue
        model, processor = model_processors[model_name]
        for prompt in prompts:
            response_text = generate_output(model, processor, prompt, image, model_name, device)
            row_data[f"{model_name}_Response_{prompt}"] = _format_response(response_text)
    return row_data


def run_inference(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype,
                  trust_remote_code, use_flash_attn, use_zip, zip_file):
    """Gradio callback: run every selected model/prompt over the chosen images.

    Images come from the uploaded ZIP when *use_zip* is set, otherwise from the
    first *num_images_input* rows of *dataset_input*'s train split.

    Returns:
        An HTML table (one row per image), or a short HTML notice when ZIP mode
        is selected without an upload.
    """
    torch_dtype_value = torch.float16 if torch_dtype == "torch.float16" else torch.float32
    device_map_value = _resolve_device(device_map)
    model_processors = _load_models(
        model_names, torch_dtype_value, device_map_value, trust_remote_code, use_flash_attn
    )

    data = []
    try:
        if use_zip:
            if zip_file is None:
                return "<p>Please upload a ZIP file of images.</p>"
            # gr.File may hand over a path string or a tempfile wrapper exposing .name,
            # depending on the Gradio version.
            zip_path = getattr(zip_file, "name", zip_file)
            images = extract_images_from_zip(zip_path)
            print("Number of images in zip:", len(images))
            for img in tqdm(images):
                try:
                    data.append(_build_row(img['image'], model_names, model_processors, prompts, device_map_value))
                except Exception as e:
                    print(f"Error processing image {img['filename']}: {e}")
                    traceback.print_exc()
        else:
            dataset = load_dataset(dataset_input, split='train')
            image_column = _DATASET_IMAGE_COLUMNS.get(dataset_input, 'image')
            # Do not index past the end of small datasets.
            count = min(int(num_images_input), len(dataset))
            for i in tqdm(range(count)):
                try:
                    image = dataset[i][image_column]
                    data.append(_build_row(image, model_names, model_processors, prompts, device_map_value))
                except Exception as e:
                    print(f"Error processing dataset row {i}: {e}")
                    traceback.print_exc()
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
    return pd.DataFrame(data).to_html(escape=False)


def create_gradio_interface():
    """Build the Gradio Blocks UI and launch it (blocks until the server stops)."""
    css = """
    #output {
        height: 500px;
        overflow: auto;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# VLM-Image-Analysis: A Vision-and-Language Modeling Framework.")
        gr.Markdown("""
        - Handle a batch of images from a ZIP file OR
        - Processes images from an HF DB
        - Compatible with png, jpg, jpeg, and webp formats
        - Compatibility with various AI models: Qwen2-VL-7B-Instruct, Qwen2-VL-2B-Instruct-GPTQ-Int4, InternVL2-1B, Qwen2-VL-72B, Phi-3-vision-128k-instruct and moondream2""")

        with gr.Tab("VLM model and Dataset selection"):
            gr.Markdown("### Dataset Selection: HF or from a ZIP file.")
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    use_zip_input = gr.Checkbox(label="Use ZIP File", value=False)
                    dataset_input = gr.Dropdown(choices=dataset_options, label="Select Dataset",
                                                value=dataset_options[1], visible=True)
                    num_images_input = gr.Radio(choices=[1, 5, 20], label="Number of Images", value=5)
                    zip_file_input = gr.File(label="Upload ZIP File of Images", file_types=[".zip"])

            gr.Markdown("### VLM Model Selection")
            with gr.Row():
                with gr.Column():
                    # CheckboxGroup values are lists of selected choices.
                    models_input = gr.CheckboxGroup(choices=available_models, label="Select Models",
                                                    value=[available_models[4]])
                    prompts_input = gr.CheckboxGroup(choices=text_prompts, label="Select Prompts",
                                                     value=[text_prompts[2]])
                    submit_btn = gr.Button("Run Inference")

            with gr.Row():
                # elem_id ties this component to the #output CSS rule above.
                output_display = gr.HTML(label="Results", elem_id="output")

        with gr.Tab("GPU Device Settings"):
            device_map_input = gr.Radio(choices=["auto", "cpu", "cuda"], label="Device Map", value="auto")
            torch_dtype_input = gr.Radio(choices=["torch.float16", "torch.float32"], label="Torch Dtype",
                                         value="torch.float16")
            trust_remote_code_input = gr.Checkbox(label="Trust Remote Code", value=True)
            use_flash_attn = gr.Checkbox(label="Use flash-attn 2 (Ampere GPUs or newer.)", value=False)

        def run_inference_wrapper(model_names, dataset_input, num_images_input, prompts, device_map,
                                  torch_dtype, trust_remote_code, use_flash_attn, use_zip, zip_file):
            return run_inference(model_names, dataset_input, num_images_input, prompts, device_map,
                                 torch_dtype, trust_remote_code, use_flash_attn, use_zip, zip_file)

        def toggle_dataset_visibility(use_zip):
            # The dataset picker is irrelevant while a ZIP upload is the image source.
            return gr.update(visible=not use_zip)

        submit_btn.click(
            fn=run_inference_wrapper,
            inputs=[models_input, dataset_input, num_images_input, prompts_input, device_map_input,
                    torch_dtype_input, trust_remote_code_input, use_flash_attn, use_zip_input,
                    zip_file_input],
            outputs=output_display,
        )

        use_zip_input.change(
            fn=toggle_dataset_visibility,
            inputs=use_zip_input,
            outputs=dataset_input,
        )

    demo.launch(debug=True, share=False)


if __name__ == "__main__":
    create_gradio_interface()