"""Gradio viewer for the gigant/tib-bench-vlm dataset.

For each talk it renders the slide keyframes interleaved with the transcript
segments, and reports how many tokens the interleaved document costs for
Qwen2-VL under a fixed per-image token budget.
"""

import base64
import io

import gradio as gr
from datasets import load_dataset
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

# Marker inserted into the transcript wherever a slide keyframe occurs.
# NOTE(review): the original marker literal was lost when markup was stripped
# from this file (the code appended f"" and then called text.split(""), which
# raises ValueError: empty separator). A non-empty token restores the intended
# interleaving; confirm the exact marker string against the original repo.
IMAGE_TOKEN = "<image>"

# Qwen2-VL vision budget: each 28x28 pixel patch costs one token, so capping
# max_pixels at max_token_budget * 28 * 28 bounds the tokens per image.
max_token_budget = 512
min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
ds = load_dataset("gigant/tib-bench-vlm")["train"]


def segments(example):
    """Build one string interleaving IMAGE_TOKEN markers with the transcript.

    One marker is appended per keyframe, followed by every transcript segment
    whose seek position falls before that keyframe's end timestamp.

    Args:
        example: a dataset row with "keyframes" (incl. "timestamp" pairs, in
            seconds) and Whisper-style "transcript_segments" ("seek" in
            centiseconds, "text").

    Returns:
        The interleaved text, starting with a leading IMAGE_TOKEN.
    """
    text = ""
    segment_i = 0
    seeks = example["transcript_segments"]["seek"]
    texts = example["transcript_segments"]["text"]
    for timestamp in example["keyframes"]["timestamp"]:
        text += IMAGE_TOKEN
        end = timestamp[1]  # timestamp is (start, end); only end is compared
        # Whisper "seek" is in centiseconds; * 0.01 converts to seconds.
        while segment_i < len(seeks) and end > seeks[segment_i] * 0.01:
            text += texts[segment_i]
            segment_i += 1
    return text


def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """
    Create an HTML string with interleaved images and text segments.

    Each slide is scaled by `scale`, further capped at `max_width` pixels wide
    (preserving aspect ratio), converted to base64 and embedded directly in
    the HTML.
    """
    html = []
    parts = text.split(IMAGE_TOKEN)
    for j, segment in enumerate(parts):
        # parts[0] is the (empty) text before the leading marker, so the image
        # preceding parts[j] is slides[j - 1]; skip it for j == 0.
        if j > 0:
            img = slides[j - 1]
            img_width = int(img.width * scale)
            img_height = int(img.height * scale)
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = int(img_height * ratio)
            # Convert the resized image to an embedded base64 PNG.
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            # NOTE(review): the <img> markup was stripped from the original
            # source (img_str was computed but never used); reconstructed so
            # the encoded slide is actually rendered — confirm original styling.
            html.append(
                f'<img src="data:image/png;base64,{img_str}" '
                f'width="{img_width}" height="{img_height}" '
                f'style="display: block; margin: 8px auto;">'
            )
        # Add the text segment after the image.
        html.append(f"<div>{segment}</div>")
    return "".join(html)


def doc_to_messages(text, slides):
    """Tokenize the interleaved document with the Qwen2-VL processor.

    Splits `text` on IMAGE_TOKEN, builds a single-user-turn chat message with
    image/text content parts, and returns the processor's tokenized inputs
    (a BatchFeature whose input_ids length measures the document's token cost).
    """
    content = []
    parts = text.split(IMAGE_TOKEN)
    for j, segment in enumerate(parts):
        if j > 0:
            content.append({"type": "image", "image": slides[j - 1]})
        content.append({"type": "text", "text": segment})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    # Preparation for inference: render the chat template, then tokenize the
    # text together with the extracted vision inputs.
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(chat_text)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs


# Global navigation state for the viewer.
current_doc_index = 0
annotations = []


def load_document(index):
    """Load a specific document from the dataset.

    Returns a (title, abstract, interleaved_html, token_count) tuple, or four
    empty strings when `index` is out of range.
    """
    if 0 <= index < len(ds):
        doc = ds[index]
        segments_doc = segments(doc)
        return (
            doc["title"],
            doc["abstract"],
            create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
            doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
        )
    return ("", "", "", "")


def get_next_document():
    """Advance to the next document (wrapping around) and load it."""
    global current_doc_index
    current_doc_index = (current_doc_index + 1) % len(ds)
    return load_document(current_doc_index)


def get_prev_document():
    """Step back to the previous document (wrapping around) and load it."""
    global current_doc_index
    current_doc_index = (current_doc_index - 1) % len(ds)
    return load_document(current_doc_index)


theme = gr.themes.Ocean()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")
    with gr.Row():
        with gr.Column():
            body = gr.HTML(max_height=400)

            # Pass-through used as the title.change handler below: simply
            # re-emits the current body HTML.
            def update_interleaved_view(title, abstract, body, token_count):
                return body

        with gr.Column():
            title = gr.Textbox(label="Title", interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
            token_count = gr.Textbox(
                label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)",
                interactive=False,
                max_lines=1,
            )
            title.change(
                fn=update_interleaved_view,
                inputs=[title, abstract, body, token_count],
                outputs=body,
            )

    # Load the first document as the components' initial values.
    # NOTE(review): mutating component .value after construction works on the
    # Gradio versions this targets; demo.load(...) would be the portable way.
    title_val, abstract_val, body_val, token_count_val = load_document(
        current_doc_index
    )
    title.value = title_val
    abstract.value = abstract_val
    body.value = body_val
    token_count.value = str(token_count_val)

    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(
            fn=get_prev_document,
            inputs=[],
            outputs=[title, abstract, body, token_count],
        )
        next_button = gr.Button("Next Document")
        next_button.click(
            fn=get_next_document,
            inputs=[],
            outputs=[title, abstract, body, token_count],
        )

demo.launch()