import base64
import functools
import io
import os
import sqlite3
import traceback
import zipfile

import gradio as gr
import pandas as pd
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
# Hugging Face id and pinned revision for the vikhyatk/moondream2 model; the
# same revision is passed when loading both the model and its tokenizer.
MOON_DREAM_MODEL_ID = "vikhyatk/moondream2"
MOON_DREAM_REVISION = "2024-08-26"
# Model ids that run_inference loads with Qwen2VLForConditionalGeneration and
# that generate_output prompts through the Qwen2-VL chat-template path.
# NOTE(review): 'OpenGVLab/InternVL2-1B' is routed through the Qwen2-VL loader
# here, but it is presumably a different architecture, so loading it this way
# likely fails — confirm and give it a dedicated loader, or remove it.
QWEN2_VL_MODELS = [
    'Qwen/Qwen2-VL-7B-Instruct',
    'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4',
    'OpenGVLab/InternVL2-1B',
    'Qwen/Qwen2-VL-72B',
]
# Every model id offered in the UI's model selector: the QWEN2_VL_MODELS entries
# (in order) followed by Phi-3-vision and moondream2.
available_models = [
    *QWEN2_VL_MODELS,  # Expands the QWEN2_VL_MODELS list into the available_models
    'microsoft/Phi-3-vision-128k-instruct',
    'vikhyatk/moondream2'
]
# Hugging Face dataset ids offered in the UI; run_inference loads the chosen one
# with split='train'.
dataset_options = [
    "gokaygokay/panorama_hdr_dataset",
    "OpenGVLab/CRPE"
]
# Prompt texts offered in the UI. run_inference embeds the full prompt text in
# each result column name (f"{model_name}_Response_{prompt}").
text_prompts = [
    "Provide a detailed description of the image contents, including all visible objects, people, activities, and extract any text present within the image using Optical Character Recognition (OCR). Organize the extracted text in a structured table format with columns for original text, its translation into English, and the language it is written in.",
    "Offer a thorough description of all elements within the image, from objects to individuals and their activities. Ensure any legible text seen in the image is extracted using Optical Character Recognition (OCR). Provide an accurate narrative that encapsulates the full content of the image.",
    "Create a four-sentence caption for the image. Start by specifying the style and type, such as painting, photograph, or digital art. In the next sentences, detail the contents and the composition clearly and concisely. Use language suited for prompting a text-to-image model, separating descriptive terms with commas instead of 'or'. Keep the description direct, avoiding interpretive phrases or abstract expressions",
]
# SQLite setup
# def init_db():
# conn = sqlite3.connect('image_outputs.db')
# cursor = conn.cursor()
# cursor.execute('''
# CREATE TABLE IF NOT EXISTS image_outputs (
# id INTEGER PRIMARY KEY AUTOINCREMENT,
# image BLOB,
# prompt TEXT,
# output TEXT,
# model_name TEXT
# )
# ''')
# conn.commit()
# conn.close()
def image_to_binary(image_path):
    """Return the raw bytes of the file stored at *image_path*."""
    with open(image_path, mode="rb") as handle:
        payload = handle.read()
    return payload
# def store_in_db(image_path, prompt, output, model_name):
# conn = sqlite3.connect('image_outputs.db')
# cursor = conn.cursor()
# image_blob = image_to_binary(image_path)
# cursor.execute('''
# INSERT INTO image_outputs (image, prompt, output, model_name)
# VALUES (?, ?, ?, ?)
# ''', (image_blob, prompt, output, model_name))
# conn.commit()
# conn.close()
# Function to encode an image as an inline base64 <img> tag for HTML display
def encode_image(image, max_width=300):
    """Encode *image* as a PNG data-URI ``<img>`` element.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method,
            typically a ``PIL.Image.Image``.
        max_width: Optional CSS max-width in pixels so large images (e.g.
            panoramas) fit the results table; ``None`` emits no style.

    Returns:
        str: ``<img src="data:image/png;base64,...">`` markup, intended for
        ``DataFrame.to_html(escape=False)`` / ``gr.HTML``.
    """
    img_buffer = io.BytesIO()
    image.save(img_buffer, format="PNG")
    img_str = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
    style = f' style="max-width: {max_width}px; height: auto;"' if max_width is not None else ""
    # The payload must actually be embedded: returning an empty string left
    # every "Image" cell in the results table blank.
    return f'<img src="data:image/png;base64,{img_str}"{style} />'
# Function to load images from a Hugging Face dataset and render them as HTML
def load_dataset_images(dataset_name, num_images):
    """Render the first *num_images* images of a Hugging Face dataset as HTML.

    Each image is shown with its 1-based index, its width/height and, when the
    row has one, its ``hdr`` field.

    Args:
        dataset_name: Hugging Face dataset id; the ``train`` split is loaded.
        num_images: Maximum number of rows to render (clamped to the dataset size).

    Returns:
        str: Concatenated HTML snippets, or a human-readable message when no
        image could be loaded or loading failed.
    """
    try:
        dataset = load_dataset(dataset_name, split='train')
        count = min(int(num_images), len(dataset))
        images = []
        # Index row by row: ``dataset[:n]`` returns a dict of columns, so
        # iterating it would yield column names instead of rows.
        for i in range(count):
            item = dataset[i]
            # panorama_hdr_dataset stores its picture under 'png_image'.
            img = item.get('image', item.get('png_image'))
            if img is None:
                continue
            encoded_img = encode_image(img)
            metadata = f"Width: {img.width}, Height: {img.height}"
            if 'hdr' in item:
                metadata += f", HDR: {item['hdr']}"
            images.append(
                f"<div><h3>Image {i+1}</h3>{encoded_img}<p>{metadata}</p></div>"
            )
        if not images:
            return "No images could be loaded from this dataset. Please check the dataset structure."
        return "".join(images)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        # Return a message (not None) so an HTML consumer always gets a string.
        return f"Error loading dataset: {e}"
# Cached moondream2 tokenizer loader and the per-model generation function
@functools.lru_cache(maxsize=1)
def _get_moondream_tokenizer():
    """Load the pinned moondream2 tokenizer once and reuse it for every call."""
    return AutoTokenizer.from_pretrained(MOON_DREAM_MODEL_ID, revision=MOON_DREAM_REVISION)


def generate_output(model, processor, prompt, image, model_name, device):
    """Run a single prompt against a single image with a loaded VLM.

    Args:
        model: The loaded model for *model_name*.
        processor: The matching ``AutoProcessor``; unused for moondream2,
            which relies on its own tokenizer.
        prompt: Text instruction sent together with the image.
        image: The image to describe (a PIL image).
        model_name: Hugging Face id selecting the prompt/generation recipe.
        device: Device the input tensors are moved to (e.g. "cuda" or "cpu").

    Returns:
        str: The decoded response, or an error-message string. Failures are
        reported in-band (and their traceback printed) so one bad
        image/model pair does not abort a whole batch.
    """
    try:
        if model_name in QWEN2_VL_MODELS:
            # The chat template only needs an image placeholder; the pixels
            # are handed to the processor separately.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            generated_ids = model.generate(**inputs, max_new_tokens=1024)
            # Drop the echoed prompt tokens so only the new completion is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
            ]
            return processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

        if model_name == 'microsoft/Phi-3-vision-128k-instruct':
            messages = [{"role": "user", "content": f"<|image_1|>\n{prompt}"}]
            chat_prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(chat_prompt, [image], return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=1024)
            # Keep only the tokens generated after the prompt.
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            return processor.batch_decode(generate_ids, skip_special_tokens=True)[0]

        if model_name == 'vikhyatk/moondream2':
            # Reloading the tokenizer for every image/prompt pair was wasteful;
            # the cached loader fetches it once per process.
            tokenizer = _get_moondream_tokenizer()
            enc_image = model.encode_image(image)
            return model.answer_question(enc_image, prompt, tokenizer)

        # Previously an unknown id fell through and silently returned None.
        return f"Error during generation with model {model_name}: unsupported model"
    except Exception as e:
        traceback.print_exc()
        return f"Error during generation with model {model_name}: {e}"
# Function to list and encode images from a directory
def list_images(directory_path):
    """Encode every PNG/JPEG/WebP file in *directory_path* for HTML display.

    Args:
        directory_path: Directory to scan (non-recursively).

    Returns:
        list[dict]: One ``{"filename": str, "image": str}`` entry per image
        file, in sorted filename order, where ``"image"`` is the output of
        ``encode_image``.
    """
    images = []
    # Sort so the result order is deterministic (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(directory_path)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
            continue
        image_path = os.path.join(directory_path, filename)
        if not os.path.isfile(image_path):
            continue
        # encode_image needs an image object with .save(); passing the path
        # string straight through raised AttributeError.
        with Image.open(image_path) as img:
            encoded_img = encode_image(img)
        images.append({"filename": filename, "image": encoded_img})
    return images
# Function to extract images from a ZIP file
def extract_images_from_zip(zip_file):
    """Decode every PNG/JPEG/WebP image stored in a ZIP archive.

    Args:
        zip_file: Path or binary file-like object accepted by
            ``zipfile.ZipFile`` (e.g. the value of a ``gr.File`` upload).
            ``None`` (nothing uploaded) yields an empty list.

    Returns:
        list[dict]: ``{"filename": str, "image": PIL.Image.Image}`` entries in
        archive order, each image converted to RGB. Members that fail to
        decode are reported on stdout and skipped.
    """
    images = []
    if zip_file is None:
        return images
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            name = file_info.filename
            if file_info.is_dir() or not name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
                continue
            # macOS archivers add "__MACOSX/._<name>" resource-fork entries that
            # carry an image extension but are not images.
            if name.startswith('__MACOSX/') or name.rsplit('/', 1)[-1].startswith('._'):
                continue
            with zip_ref.open(file_info) as file:
                try:
                    # convert() forces a full decode while the member is still open.
                    img = Image.open(file).convert("RGB")
                except Exception as e:
                    print(f"Error opening image {name}: {e}")
                    continue
            images.append({"filename": name, "image": img})
    return images
# Helpers and the Gradio callback that runs every selected model over a batch of images
def _resolve_device(device_map):
    """Map the UI device choice to a device string; "auto" prefers CUDA when available."""
    if device_map == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    # Explicit choices are honoured (previously "cpu" was ignored whenever CUDA existed).
    return device_map


def _load_model(model_name, device, dtype, trust_remote_code, use_flash_attn):
    """Load ``(model, processor)`` for *model_name*, or return None for an unknown id."""
    if model_name in QWEN2_VL_MODELS:
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=device,
        ).eval()
        return model, AutoProcessor.from_pretrained(model_name)
    if model_name == 'microsoft/Phi-3-vision-128k-instruct':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=dtype,
            trust_remote_code=trust_remote_code,
            # The Phi-3-vision model card selects the attention kernel through
            # _attn_implementation ("eager" disables flash attention); a plain
            # use_flash_attn kwarg is not one of its documented options.
            _attn_implementation="flash_attention_2" if use_flash_attn else "eager",
        ).eval()
        return model, AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    if model_name == 'vikhyatk/moondream2':
        model = AutoModelForCausalLM.from_pretrained(
            MOON_DREAM_MODEL_ID,
            trust_remote_code=True,
            revision=MOON_DREAM_REVISION,
        ).eval()
        return model, None  # No processor needed for this model
    return None


def _collect_responses(image, model_names, model_processors, prompts, device):
    """Return ``{"<model>_Response_<prompt>": text}`` for every loaded model/prompt pair."""
    responses = {}
    for model_name in model_names:
        if model_name not in model_processors:
            continue  # the model failed to load, so it gets no columns
        model, processor = model_processors[model_name]
        for prompt in prompts:
            column = f"{model_name}_Response_{prompt}"
            try:
                responses[column] = generate_output(model, processor, prompt, image, model_name, device)
            except Exception as e:
                responses[column] = f"Error during generation with model {model_name}: {e}"
                traceback.print_exc()
    return responses


def run_inference(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code, use_flash_attn, use_zip, zip_file):
    """Run every selected model/prompt pair over a batch of images.

    Args:
        model_names: Model ids to load (see ``available_models``).
        dataset_input: Hugging Face dataset id, used when *use_zip* is false.
        num_images_input: Number of leading dataset rows to process; clamped
            to the dataset size.
        prompts: Prompt texts to run against each image.
        device_map: "auto", "cpu" or "cuda"; "auto" picks CUDA if available.
        torch_dtype: "torch.float16" selects float16, anything else float32.
        trust_remote_code: Forwarded to the Phi-3-vision loaders.
        use_flash_attn: Request flash-attention 2 for Phi-3-vision.
        use_zip: Read images from *zip_file* instead of the dataset.
        zip_file: Uploaded ZIP archive (path or file object).

    Returns:
        str: An HTML table with an "Image" column plus one
        "<model>_Response_<prompt>" column per pair. Load and generation
        failures are printed and reported in-band rather than raised.
    """
    # A single pre-selected choice may arrive as a bare string.
    if isinstance(model_names, str):
        model_names = [model_names]
    if isinstance(prompts, str):
        prompts = [prompts]
    model_names = model_names or []
    prompts = prompts or []

    torch_dtype_value = torch.float16 if torch_dtype == "torch.float16" else torch.float32
    device_map_value = _resolve_device(device_map)

    model_processors = {}
    for model_name in model_names:
        try:
            loaded = _load_model(model_name, device_map_value, torch_dtype_value, trust_remote_code, use_flash_attn)
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            continue
        if loaded is None:
            # Previously an unknown id reused/referenced an unbound `model`.
            print(f"Error loading model {model_name}: unsupported model")
            continue
        model_processors[model_name] = loaded

    data = []
    try:
        if use_zip:
            images = extract_images_from_zip(zip_file)
            print("Number of images in zip:", len(images))
            for img in tqdm(images):
                try:
                    image = img['image']
                    row_data = {"Image": encode_image(image)}
                    row_data.update(_collect_responses(image, model_names, model_processors, prompts, device_map_value))
                    data.append(row_data)
                except Exception as e:
                    print(f"Error processing image {img['filename']}: {e}")
                    traceback.print_exc()
        else:
            dataset = load_dataset(dataset_input, split='train')
            # Clamp so a short dataset does not raise IndexError part-way through.
            count = min(int(num_images_input), len(dataset))
            # panorama_hdr_dataset stores its picture under 'png_image'; the
            # other datasets use 'image'.
            image_key = 'png_image' if dataset_input == "gokaygokay/panorama_hdr_dataset" else 'image'
            for i in tqdm(range(count)):
                # Match the ZIP path, which always feeds RGB images to the models.
                image = dataset[i][image_key].convert("RGB")
                row_data = {"Image": encode_image(image)}
                row_data.update(_collect_responses(image, model_names, model_processors, prompts, device_map_value))
                data.append(row_data)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
    return pd.DataFrame(data).to_html(escape=False)
# Gradio UI setup
def create_gradio_interface():
    """Build the VLM-Image-Analysis Gradio Blocks UI and launch it locally.

    The first tab chooses the image source (a Hugging Face dataset or an
    uploaded ZIP), the models and the prompts; the second tab holds device and
    dtype options. "Run Inference" renders ``run_inference``'s HTML table.
    This call blocks while the server runs (``demo.launch``).
    """
    # Applied to the results panel through elem_id="output" below; without
    # that id the rule never matched anything.
    css = """
    #output {
        height: 500px;
        overflow: auto;
    }
    """
    with gr.Blocks(css=css) as demo:
        # Title
        gr.Markdown("# VLM-Image-Analysis: A Vision-and-Language Modeling Framework.")
        # The list lines stay at column 0 so Markdown does not render them as code.
        gr.Markdown("""
- Handle a batch of images from a ZIP file OR
- Processes images from an HF DB
- Compatible with png, jpg, jpeg, and webp formats
- Compatibility with various AI models: Qwen2-VL-7B-Instruct, Qwen2-VL-2B-Instruct-GPTQ-Int4, InternVL2-1B, Qwen2-VL-72B, Phi-3-vision-128k-instruct and moondream2""")
        image_path = os.path.abspath("static/image.jpg")
        gr.Image(value=image_path, label="HF Image", width=300, height=300)

        with gr.Tab("VLM model and Dataset selection"):
            gr.Markdown("### Dataset Selection: HF or from a ZIP file.")
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    use_zip_input = gr.Checkbox(label="Use ZIP File", value=False)
                    dataset_input = gr.Dropdown(choices=dataset_options, label="Select Dataset", value=dataset_options[1], visible=True)
                    num_images_input = gr.Radio(choices=[1, 5, 20], label="Number of Images", value=5)
                    # Hidden until "Use ZIP File" is ticked (see toggle below).
                    zip_file_input = gr.File(label="Upload ZIP File of Images", file_types=[".zip"], visible=False)
            gr.Markdown("### VLM Model Selection")
            with gr.Row():
                with gr.Column():
                    # Pre-select by id rather than by a list index that shifts
                    # whenever available_models is edited.
                    models_input = gr.CheckboxGroup(choices=available_models, label="Select Models", value=["microsoft/Phi-3-vision-128k-instruct"])
                    prompts_input = gr.CheckboxGroup(choices=text_prompts, label="Select Prompts", value=[text_prompts[2]])
                    submit_btn = gr.Button("Run Inference")
            with gr.Row():
                output_display = gr.HTML(label="Results", elem_id="output")

        with gr.Tab("GPU Device Settings"):
            device_map_input = gr.Radio(choices=["auto", "cpu", "cuda"], label="Device Map", value="auto")
            torch_dtype_input = gr.Radio(choices=["torch.float16", "torch.float32"], label="Torch Dtype", value="torch.float16")
            trust_remote_code_input = gr.Checkbox(label="Trust Remote Code", value=True)
            use_flash_attn = gr.Checkbox(label="Use flash-attn 2 (Ampere GPUs or newer.)", value=False)

        def toggle_dataset_visibility(use_zip):
            # Show only the controls relevant to the selected image source.
            return (
                gr.update(visible=not use_zip),
                gr.update(visible=not use_zip),
                gr.update(visible=use_zip),
            )

        # The input order must match run_inference's positional parameters.
        submit_btn.click(
            fn=run_inference,
            inputs=[models_input, dataset_input, num_images_input, prompts_input, device_map_input, torch_dtype_input, trust_remote_code_input, use_flash_attn, use_zip_input, zip_file_input],
            outputs=output_display,
        )
        use_zip_input.change(
            fn=toggle_dataset_visibility,
            inputs=use_zip_input,
            outputs=[dataset_input, num_images_input, zip_file_input],
        )
    demo.launch(debug=True, share=False)
if __name__ == "__main__":
    # Running this file as a script builds the UI and starts the Gradio
    # server; importing it as a module launches nothing.
    create_gradio_interface()