File size: 17,123 Bytes
fadb9e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc2fb4b
 
 
fadb9e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
import os
import base64
import io
import sqlite3
import torch
import gradio as gr
import pandas as pd
from PIL import Image
import requests
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import traceback
from tqdm import tqdm 
import zipfile

# Pinned model id and revision for the vikhyatk/moondream2 model (revision
# pinned because moondream2 ships trust_remote_code that changes between tags).
MOON_DREAM_MODEL_ID = "vikhyatk/moondream2"
MOON_DREAM_REVISION = "2024-08-26"

# Models loaded through Qwen2VLForConditionalGeneration in run_inference().
# NOTE(review): 'OpenGVLab/InternVL2-1B' is not a Qwen2-VL checkpoint but is
# listed here, so it will be loaded with the Qwen2-VL class — confirm intended.
QWEN2_VL_MODELS = [
    'Qwen/Qwen2-VL-7B-Instruct',
    'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4',
    'OpenGVLab/InternVL2-1B',
    'Qwen/Qwen2-VL-72B',
]

# All models offered in the UI (Qwen2-VL list plus the two non-Qwen models).
available_models = [
    *QWEN2_VL_MODELS,  # Expands the QWEN2_VL_MODELS list into the available_models
    'microsoft/Phi-3-vision-128k-instruct',
    'vikhyatk/moondream2'
]

# Hugging Face datasets selectable as an image source in the UI.
dataset_options = [
    "gokaygokay/panorama_hdr_dataset",  
    "OpenGVLab/CRPE"  
]

# Captioning/OCR prompt presets offered in the UI.
text_prompts = [
    "Provide a detailed description of the image contents, including all visible objects, people, activities, and extract any text present within the image using Optical Character Recognition (OCR). Organize the extracted text in a structured table format with columns for original text, its translation into English, and the language it is written in.",
    "Offer a thorough description of all elements within the image, from objects to individuals and their activities. Ensure any legible text seen in the image is extracted using Optical Character Recognition (OCR). Provide an accurate narrative that encapsulates the full content of the image.",    
    "Create a four-sentence caption for the image. Start by specifying the style and type, such as painting, photograph, or digital art. In the next sentences, detail the contents and the composition clearly and concisely. Use language suited for prompting a text-to-image model, separating descriptive terms with commas instead of 'or'. Keep the description direct, avoiding interpretive phrases or abstract expressions",
]

# SQLite setup
# def init_db():
#     conn = sqlite3.connect('image_outputs.db')
#     cursor = conn.cursor()
#     cursor.execute('''
#         CREATE TABLE IF NOT EXISTS image_outputs (
#             id INTEGER PRIMARY KEY AUTOINCREMENT,
#             image BLOB,
#             prompt TEXT,
#             output TEXT,
#             model_name TEXT
#         )
#     ''')
#     conn.commit()
#     conn.close()

def image_to_binary(image_path):
    """Read the file at ``image_path`` and return its raw bytes."""
    with open(image_path, 'rb') as fh:
        contents = fh.read()
    return contents

# def store_in_db(image_path, prompt, output, model_name):
#     conn = sqlite3.connect('image_outputs.db')
#     cursor = conn.cursor()
#     image_blob = image_to_binary(image_path)
#     cursor.execute('''
#         INSERT INTO image_outputs (image, prompt, output, model_name)
#         VALUES (?, ?, ?, ?)
#     ''', (image_blob, prompt, output, model_name))
#     conn.commit()
#     conn.close()

# Render a PIL image as an inline HTML <img> tag (base64-encoded PNG data URI)
def encode_image(image):
    """Serialize ``image`` to PNG and return an HTML <img> tag embedding it."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{payload}" style="max-width:500px;"/>'

# Function to load and display images from a Hugging Face dataset as HTML
def load_dataset_images(dataset_name, num_images):
    """Load up to ``num_images`` rows from the 'train' split of ``dataset_name``
    and render every PIL image in the 'image' column as an HTML gallery string.

    Returns the HTML fragment on success, a "no images" notice when nothing
    usable was found, or an error-message string when loading fails.
    """
    try:
        dataset = load_dataset(dataset_name, split='train')
        images = []
        # BUG FIX: slicing an HF Dataset (dataset[:n]) returns a dict of
        # columns, so iterating it yielded column-name strings and never a
        # row. Use .select() to iterate real row dicts, clamped to the
        # dataset length so a large request cannot raise.
        limit = min(num_images, len(dataset))
        for i, item in enumerate(dataset.select(range(limit))):
            if 'image' in item:
                img = item['image']
                print(type(img))
                encoded_img = encode_image(img)
                metadata = f"Width: {img.width}, Height: {img.height}"
                if 'hdr' in item:
                    metadata += f", HDR: {item['hdr']}"
                images.append(f"<div style='display: inline-block; margin: 10px; text-align: center;'><h3>Image {i+1}</h3>{encoded_img}<p>{metadata}</p></div>")
        if not images:
            return "No images could be loaded from this dataset. Please check the dataset structure."
        return "".join(images)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        # Previously returned None here, which rendered as a blank cell in
        # the UI; return the error text instead.
        return f"Error loading dataset: {e}"

# Function to run one (model, prompt, image) inference and return the text
def generate_output(model, processor, prompt, image, model_name, device):
    """Generate a text response for ``image`` + ``prompt`` with ``model``.

    Dispatches on ``model_name``: Qwen2-VL chat template, Phi-3-vision chat
    template, or moondream2's encode_image/answer_question API. Returns the
    generated text, or an error-message string if generation fails.
    """
    try:
        # Serialize the PIL image once; the Qwen2-VL path consumes raw PNG bytes.
        image_bytes = io.BytesIO()
        image.save(image_bytes, format="PNG")
        image_bytes = image_bytes.getvalue()

        if model_name in QWEN2_VL_MODELS:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_bytes},
                        {"type": "text", "text": prompt},
                    ]
                }
            ]
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(
                text=[text],
                images=[Image.open(io.BytesIO(image_bytes))],
                padding=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            generated_ids = model.generate(**inputs, max_new_tokens=1024)
            # Strip the echoed prompt tokens so only the generated tail is decoded.
            generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)]
            return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        elif model_name == 'microsoft/Phi-3-vision-128k-instruct':
            messages = [{"role": "user", "content": f"<|image_1|>\n{prompt}"}]
            # Use a distinct local instead of shadowing the `prompt` parameter.
            chat_prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(chat_prompt, [image], return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=1024)
            # Drop the prompt prefix before decoding.
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            return processor.batch_decode(generate_ids, skip_special_tokens=True)[0]

        elif model_name == 'vikhyatk/moondream2':
            # PERF FIX: cache the tokenizer on the function object; the
            # original rebuilt it from the hub on every single call.
            if not hasattr(generate_output, "_moondream_tokenizer"):
                generate_output._moondream_tokenizer = AutoTokenizer.from_pretrained(
                    MOON_DREAM_MODEL_ID, revision=MOON_DREAM_REVISION
                )
            tokenizer = generate_output._moondream_tokenizer
            enc_image = model.encode_image(image)
            return model.answer_question(enc_image, prompt, tokenizer)

        # Previously fell through and returned None for an unrecognized model.
        return f"Unsupported model: {model_name}"
    except Exception as e:
        return f"Error during generation with model {model_name}: {e}"

# Function to list and encode images from a directory as HTML <img> tags
def list_images(directory_path):
    """Return one {'filename', 'image'} dict per image file in ``directory_path``.

    Non-recursive; only .png/.jpg/.jpeg files are included. 'image' is an HTML
    <img> tag embedding the file as a base64 data URI. Results are sorted by
    filename for deterministic order.
    """
    images = []
    for filename in sorted(os.listdir(directory_path)):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(directory_path, filename)
            # BUG FIX: the original passed the path string to encode_image(),
            # which calls .save() on a PIL Image and therefore raised
            # AttributeError on a str. Base64-encode the file bytes directly.
            with open(image_path, 'rb') as fh:
                img_b64 = base64.b64encode(fh.read()).decode("utf-8")
            images.append({
                "filename": filename,
                "image": f'<img src="data:image/png;base64,{img_b64}" style="max-width:500px;"/>',
            })
    return images

# Function to extract images from a ZIP file
def extract_images_from_zip(zip_file):
    """Open ``zip_file`` and return a list of {'filename', 'image'} dicts.

    'image' is an RGB PIL.Image for every .png/.jpg/.jpeg entry in the
    archive; entries that cannot be opened are logged and skipped.
    """
    allowed_suffixes = ('.png', '.jpg', '.jpeg')
    extracted = []
    with zipfile.ZipFile(zip_file, 'r') as archive:
        for entry in archive.infolist():
            if not entry.filename.lower().endswith(allowed_suffixes):
                continue
            with archive.open(entry) as handle:
                try:
                    # Force RGB so downstream PNG re-encoding never fails.
                    picture = Image.open(handle).convert("RGB")
                except Exception as e:
                    print(f"Error opening image {entry.filename}: {e}")
                else:
                    extracted.append({
                        "filename": entry.filename,
                        "image": picture,
                    })
    return extracted

# Gradio interface function for running inference over a batch of images
def run_inference(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code, use_flash_attn, use_zip, zip_file):
    """Run every selected (model, prompt) combination over a batch of images.

    Images come from an uploaded ZIP file when ``use_zip`` is True, otherwise
    from the 'train' split of the Hugging Face dataset named ``dataset_input``
    (up to ``num_images_input`` rows).

    Returns an HTML table (DataFrame rendered with escape=False) with one row
    per image and one column per (model, prompt) pair; per-cell generation
    errors are embedded in the cell, loading errors are logged.
    """
    data = []

    torch_dtype_value = torch.float16 if torch_dtype == "torch.float16" else torch.float32
    # BUG FIX: the original chained ternary parsed as
    #   "cuda" if cuda_available else ("cpu" if device_map == "auto" else device_map)
    # so an explicit "cpu"/"cuda" selection was ignored whenever CUDA was
    # available. Resolve "auto" first, then honor the explicit choice.
    if device_map == "auto":
        device_map_value = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device_map_value = device_map

    # Load each selected model (and its processor, if any) once, up front.
    model_processors = {}
    for model_name in model_names:
        try:
            if model_name in QWEN2_VL_MODELS:
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_name, 
                    torch_dtype=torch_dtype_value, 
                    device_map=device_map_value
                ).eval()
                processor = AutoProcessor.from_pretrained(model_name)
            elif model_name == 'microsoft/Phi-3-vision-128k-instruct':
                model = AutoModelForCausalLM.from_pretrained(
                    model_name, 
                    device_map=device_map_value, 
                    torch_dtype=torch_dtype_value, 
                    trust_remote_code=trust_remote_code, 
                    use_flash_attn=use_flash_attn
                ).eval()
                processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
            elif model_name == 'vikhyatk/moondream2':
                model = AutoModelForCausalLM.from_pretrained(
                    MOON_DREAM_MODEL_ID, 
                    trust_remote_code=True, 
                    revision=MOON_DREAM_REVISION
                ).eval()
                processor = None  # moondream2 handles tokenization itself
            else:
                # BUG FIX: an unrecognized name used to silently reuse the
                # `model` variable left over from the previous iteration.
                print(f"Unknown model {model_name}; skipping.")
                continue

            model_processors[model_name] = (model, processor)

        except Exception as e:
            print(f"Error loading model {model_name}: {e}")

    try:
        if use_zip:
            # Image source 1: entries extracted from the uploaded ZIP file.
            images = extract_images_from_zip(zip_file)
            print("Number of images in zip:", len(images))
            for img in tqdm(images):
                try:
                    img_data = img['image']
                    if not isinstance(img_data, str):
                        # Convert the PIL Image to a base64 PNG for HTML display.
                        img_buffer = io.BytesIO()
                        img['image'].save(img_buffer, format="PNG")
                        img_data = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                    img_data = f'<img src="data:image/png;base64,{img_data}" style="max-width:500px;"/>'

                    row_data = {"Image": img_data}
                    for model_name in model_names:
                        if model_name in model_processors:
                            model, processor = model_processors[model_name]
                            for prompt in prompts:
                                try:
                                    image = img['image']
                                    response_text = generate_output(model, processor, prompt, image, model_name, device_map_value)
                                    row_data[f"{model_name}_Response_{prompt}"] = response_text
                                except Exception as e:
                                    row_data[f"{model_name}_Response_{prompt}"] = f"Error during generation with model {model_name}: {e}"
                                    traceback.print_exc()

                    data.append(row_data)
                except Exception as e:
                    print(f"Error processing image {img['filename']}: {e}")
                    traceback.print_exc()

        else:
            # Image source 2: rows from the selected Hugging Face dataset.
            dataset = load_dataset(dataset_input, split='train')
            # Only panorama_hdr_dataset stores its image under 'png_image';
            # every other dataset (including CRPE) uses 'image'.
            image_column = 'png_image' if dataset_input == "gokaygokay/panorama_hdr_dataset" else 'image'
            # Clamp to the dataset length so a large request can't IndexError.
            for i in tqdm(range(min(num_images_input, len(dataset)))):
                image = dataset[i][image_column]

                encoded_img = encode_image(image)
                row_data = {"Image": encoded_img}

                for model_name in model_names:
                    if model_name in model_processors:
                        model, processor = model_processors[model_name]
                        for prompt in prompts:
                            try:
                                response_text = generate_output(model, processor, prompt, image, model_name, device_map_value)
                                row_data[f"{model_name}_Response_{prompt}"] = response_text
                            except Exception as e:
                                row_data[f"{model_name}_Response_{prompt}"] = f"Error during generation with model {model_name}: {e}"

                data.append(row_data)

    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()

    return pd.DataFrame(data).to_html(escape=False)

# Gradio UI setup
def create_gradio_interface():
    """Build the Gradio Blocks UI for VLM batch image analysis and launch it.

    Blocking call: ``demo.launch()`` runs until interrupted. Wires the
    "Run Inference" button to ``run_inference`` and toggles the dataset
    dropdown's visibility when the ZIP checkbox changes.
    """
    css = """
      #output {
        height: 500px;
        overflow: auto;
      }
    """
    with gr.Blocks(css=css) as demo:
        # Title and feature summary
        gr.Markdown("# VLM-Image-Analysis: A Vision-and-Language Modeling Framework.")
        gr.Markdown("""
                    - Handle a batch of images from a ZIP file OR
                    - Processes images from an HF DB 
                    - Compatible with png, jpg, jpeg, and webp formats                                        
                    - Compatibility with various AI models: Qwen2-VL-7B-Instruct, Qwen2-VL-2B-Instruct-GPTQ-Int4, InternVL2-1B, Qwen2-VL-72B, /Phi-3-vision-128k-instruct and moondream2""")

        # NOTE(review): assumes static/image.jpg exists relative to the
        # working directory — confirm before deployment.
        image_path = os.path.abspath("static/image.jpg")
        gr.Image(value=image_path, label="HF Image", width=300, height=300)
        
        with gr.Tab("VLM model and Dataset selection"):
            gr.Markdown("### Dataset Selection: HF or from a ZIP file.")
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    # with gr.Column():  
                    use_zip_input = gr.Checkbox(label="Use ZIP File", value=False)
                    dataset_input = gr.Dropdown(choices=dataset_options, label="Select Dataset", value=dataset_options[1], visible=True)
                    num_images_input = gr.Radio(choices=[1, 5, 20], label="Number of Images", value=5)
                    zip_file_input = gr.File(label="Upload ZIP File of Images", file_types=[".zip"])
            gr.Markdown("### VLM Model Selection")
            with gr.Row():
                with gr.Column():  
                    # NOTE(review): CheckboxGroup `value` is a single string
                    # here rather than a list — gradio appears to accept
                    # this, but confirm it pre-selects as intended.
                    models_input = gr.CheckboxGroup(choices=available_models, label="Select Models", value=available_models[4])
                    prompts_input = gr.CheckboxGroup(choices=text_prompts, label="Select Prompts", value=text_prompts[2])
                    submit_btn = gr.Button("Run Inference")
                
            with gr.Row():        
                output_display = gr.HTML(label="Results")

        with gr.Tab("GPU Device Settings"):
            device_map_input = gr.Radio(choices=["auto", "cpu", "cuda"], label="Device Map", value="auto")
            torch_dtype_input = gr.Radio(choices=["torch.float16", "torch.float32"], label="Torch Dtype", value="torch.float16")
            trust_remote_code_input = gr.Checkbox(label="Trust Remote Code", value=True)
            use_flash_attn = gr.Checkbox(label="Use flash-attn 2 (Ampere GPUs or newer.)", value=False)
            
                               
        # Thin pass-through so the click handler's arity matches run_inference.
        def run_inference_wrapper(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code,use_flash_attn, use_zip, zip_file):
            return run_inference(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code,use_flash_attn, use_zip, zip_file)
        
        # Hide the HF dataset dropdown while a ZIP file is the image source.
        def toggle_dataset_visibility(use_zip):
            return gr.update(visible=not use_zip)
        
        submit_btn.click(
            fn=run_inference_wrapper, 
            inputs=[models_input, dataset_input, num_images_input, prompts_input, device_map_input, torch_dtype_input, trust_remote_code_input,use_flash_attn, use_zip_input, zip_file_input], 
            outputs=output_display
        )
        
        use_zip_input.change(
            fn=toggle_dataset_visibility, 
            inputs=use_zip_input, 
            outputs=dataset_input
        )

    demo.launch(debug=True, share=False)

if __name__ == "__main__":
    # Entry point: build and launch the Gradio UI (blocks until interrupted).
    create_gradio_interface()