gigant committed on
Commit 6cdcc54 · verified · 1 Parent(s): b86baf2

Create app.py

Files changed (1)
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
+ import gradio as gr
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ from PIL import Image
+ import io
+ import base64
+ from datasets import load_dataset
+
+
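+ # Qwen2-VL's processor resizes each image so its pixel area stays within
+ # [min_pixels, max_pixels]; one visual token corresponds to a 28x28-pixel patch,
+ # so max_pixels = max_token_budget * 28 * 28 keeps each slide under ~512 visual tokens.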
+ max_token_budget = 512
+
+ min_pixels = 1 * 28 * 28
+ max_pixels = max_token_budget * 28 * 28
+ processor = AutoProcessor.from_pretrained(
+     "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
+ )
+
+ ds = load_dataset("gigant/tib-bench-vlm")["train"]
+
+
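+ # Build one interleaved text per talk: for every extracted keyframe, emit an "<image>"
+ # placeholder, then append each transcript segment that starts before the keyframe's end
+ # time (the Whisper-style "seek" values appear to be in centiseconds, hence the * 0.01).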
+ def segments(example):
+     # create a text with the <image> tokens from the timestamps of the extracted keyframes and transcript
+     text = ""
+     segment_i = 0
+     for i, timestamp in enumerate(example['keyframes']['timestamp']):
+         text += "<image>"  # alternatively: f"<image {i}>"
+         start, end = timestamp[0], timestamp[1]
+         while segment_i < len(example["transcript_segments"]["seek"]) and end > example["transcript_segments"]["seek"][segment_i] * 0.01:
+             text += example["transcript_segments"]["text"][segment_i]
+             segment_i += 1
+     return text
+
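+ # Slides are embedded as base64 data URIs so the interleaved view is a single
+ # self-contained HTML string that gr.HTML can render without serving image files.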
+ def create_interleaved_html(text, slides, scale=0.4, max_width=600):
+     """
+     Creates an HTML string with interleaved images and text segments.
+     The images are converted to base64 and embedded directly in the HTML.
+     """
+     html = []
+     segments = text.split("<image>")
+
+     for j, segment in enumerate(segments):  # the first segment (j == 0) has no preceding image
+         # Add the image
+         if j > 0:
+             img = slides[j - 1]
+             img_width = int(img.width * scale)
+             img_height = int(img.height * scale)
+             if img_width > max_width:
+                 ratio = max_width / img_width
+                 img_width = max_width
+                 img_height = int(img_height * ratio)
+
+             # Convert image to base64
+             buffer = io.BytesIO()
+             img.resize((img_width, img_height)).save(buffer, format="PNG")
+             img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+             html.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
+         # Add the text segment after the image
+         html.append(f'<div style="white-space: pre-wrap;">{segment}</div>')
+
+     return "".join(html)
+
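+ # Build the Qwen2-VL chat message with interleaved images and text, then run the processor
+ # to get model-ready tensors; the app only uses the result to report the input token count
+ # (inputs.input_ids.shape[1] in load_document below).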
+ def doc_to_messages(text, slides):
+     content = []
+     segments = text.split("<image>")
+     for j, segment in enumerate(segments):
+         if j > 0:
+             content.append({"type": "image", "image": slides[j - 1]})
+         content.append({"type": "text", "text": segment})
+     messages = [
+         {
+             "role": "user",
+             "content": content,
+         }
+     ]
+     # Preparation for inference
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     print(text)
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     return inputs
+
+
+ # Global variables to keep track of current document
+ current_doc_index = 0
+ annotations = []
+
+ def load_document(index):
+     """Load a specific document from the dataset"""
+     if 0 <= index < len(ds):
+         doc = ds[index]
+         segments_doc = segments(doc)
+         return (
+             doc["title"],
+             doc["abstract"],
+             create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
+             doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
+         )
+     return ("", "", "", "")
+
+ def get_next_document():
+     """Get the next document in the dataset"""
+     global current_doc_index
+     current_doc_index = (current_doc_index + 1) % len(ds)
+     return load_document(current_doc_index)
+
+ def get_prev_document():
+     """Get the previous document in the dataset"""
+     global current_doc_index
+     current_doc_index = (current_doc_index - 1) % len(ds)
+     return load_document(current_doc_index)
+
+
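+ # Gradio UI: interleaved slides + transcript on the left, title / abstract / token count
+ # on the right, and Previous / Next buttons to browse the dataset.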
+ theme = gr.themes.Ocean()
+
+ with gr.Blocks(theme=theme) as demo:
+     gr.Markdown("# Slide Presentation Visualization Tool")
+     with gr.Row():
+         with gr.Column():
+             body = gr.HTML(max_height=400)
+
+         # Function to update the interleaved view
+         def update_interleaved_view(title, abstract, body, token_count):
+             return body
+
+         with gr.Column():
+             title = gr.Textbox(label="Title", interactive=False, max_lines=1)
+             abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
+             token_count = gr.Textbox(label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)", interactive=False, max_lines=1)
+
+     title.change(
+         fn=update_interleaved_view,
+         inputs=[title, abstract, body, token_count],
+         outputs=body,
+     )
+     # Load first document
+     title_val, abstract_val, body_val, token_count_val = load_document(current_doc_index)
+     title.value = title_val
+     abstract.value = abstract_val
+     body.value = body_val
+     token_count.value = str(token_count_val)
+
+
+     with gr.Row():
+         prev_button = gr.Button("Previous Document")
+         prev_button.click(fn=get_prev_document, inputs=[], outputs=[title, abstract, body, token_count])
+         next_button = gr.Button("Next Document")
+         next_button.click(fn=get_next_document, inputs=[], outputs=[title, abstract, body, token_count])
+
+ demo.launch()