gabrielaltay committed
Commit bcf0900 · 1 Parent(s): 515ab32

more stuff

Files changed (1)
  1. app.py +76 -31
app.py CHANGED
@@ -23,8 +23,9 @@ def initialize_session_state():
     keys = [
         "colpali_model",
         "page_images",
+        "page_embeddings",
         "retrieved_page_images",
-        "response",
+        "retrieved_page_scores" "response",
     ]
     for key in keys:
         if key not in SS:
@@ -68,7 +69,7 @@ def load_colpali_model():
     return model, processor
 
 
-def embed_page_images(model, processor, page_images, batch_size=2):
+def embed_page_images(model, processor, page_images, batch_size=1):
     dataloader = DataLoader(
         page_images,
         batch_size=batch_size,
@@ -76,11 +77,13 @@ def embed_page_images(model, processor, page_images, batch_size=2):
         collate_fn=lambda x: process_images(processor, x),
     )
     page_embeddings = []
-    for batch in dataloader:
+    pbar = st.progress(0, text="embedding pages")
+    for ibatch, batch in enumerate(dataloader):
         with torch.no_grad():
             batch = {k: v.to(model.device) for k, v in batch.items()}
             embeddings = model(**batch)
             page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
+        pbar.progress((ibatch + 1) / len(page_images), text="embedding pages")
     return np.array(page_embeddings)
 
 
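For orientation: embed_page_images returns one multi-vector embedding per page (one vector per image patch/token), stacked into a single array. A minimal sketch of inspecting that output, assuming a ColPali-style layout of (num_pages, num_token_vectors_per_page, embedding_dim); the concrete numbers below are placeholders, not values taken from this commit:

    import numpy as np

    # stand-in for the array returned by embed_page_images(...)
    page_embeddings = np.zeros((10, 1030, 128))  # placeholder shape, for illustration only

    num_pages, num_tokens, emb_dim = page_embeddings.shape
    print(f"{num_pages} pages, {num_tokens} vectors per page, dim {emb_dim}")

    # np.array(...) in the function above stacks cleanly only if every page
    # yields the same number of token vectors, as it does at a fixed input resolution

The per-page multi-vector layout is what the MaxSim scoring further down relies on.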
@@ -102,14 +105,15 @@ def embed_query_texts(model, processor, query_texts, batch_size=1):
     return np.array(query_embeddings)[0]
 
 
-
 def get_pdf_page_images_from_bytes(
     pdf_bytes: bytes,
     use_tmp_dir=False,
 ):
     if use_tmp_dir:
         with tempfile.TemporaryDirectory() as tmp_path:
-            page_images = pdf2image.convert_from_bytes(pdf_bytes, output_folder=tmp_path)
+            page_images = pdf2image.convert_from_bytes(
+                pdf_bytes, output_folder=tmp_path
+            )
     else:
         page_images = pdf2image.convert_from_bytes(pdf_bytes)
     return page_images
@@ -125,13 +129,17 @@ def get_pdf_bytes_from_url(url: str) -> bytes | None:
     return None
 
 
-def display_pages(page_images, key):
+def display_pages(page_images, key, captions=None):
     n_cols = st.slider("ncol", min_value=1, max_value=8, value=4, step=1, key=key)
     cols = st.columns(n_cols)
     for ii_page, page_image in enumerate(page_images):
         ii_col = ii_page % n_cols
         with cols[ii_col]:
-            st.image(page_image)
+            if captions is not None:
+                caption = captions[ii_page]
+            else:
+                caption = None
+            st.image(page_image, caption=caption)
 
 
 initialize_session_state()
@@ -142,24 +150,59 @@ if SS["colpali_model"] is None:
 
 
 with st.sidebar:
-    url = st.text_input("arxiv url", "https://arxiv.org/pdf/2112.01488.pdf")
 
-    if st.button("load paper"):
-        pdf_bytes = get_pdf_bytes_from_url(url)
-        SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)
+    with st.container(border=True):
+        st.header("Load PDF (URL or Upload)")
+        st.write("When a PDF is loaded, each page will be turned into an image.")
 
+        url = st.text_input("Provide a URL", "https://arxiv.org/pdf/2404.15549v2")
+        if st.button("load paper from url"):
+            pdf_bytes = get_pdf_bytes_from_url(url)
+            SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)
 
-    if st.button("embed pages"):
-        SS["page_embeddings"] = embed_page_images(
-            SS["colpali_model"],
-            SS["processor"],
-            SS["page_images"],
+        uploaded_file = st.file_uploader("Upload a file", type=["pdf"])
+        if uploaded_file is not None:
+            pdf_bytes = uploaded_file.getvalue()
+            SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)
+
+    with st.container(border=True):
+        st.header("Embed Page Images")
+        st.write(
+            "In order to retrieve relevant images for a query, we must first embed the images."
         )
+        if st.button("embed pages"):
+            SS["page_embeddings"] = embed_page_images(
+                SS["colpali_model"],
+                SS["processor"],
+                SS["page_images"],
+            )
+
+    if SS["page_images"] is not None:
+        st.write("Num Page Images: {}".format(len(SS["page_images"])))
+
+    if SS["page_embeddings"] is not None:
+        st.write("Page Embeddings Shape: {}".format(SS["page_embeddings"].shape))
 
 
 with st.container(border=True):
     query = st.text_area("query")
-    top_k = st.slider("num pages to retrieve", min_value=1, max_value=8, value=3, step=1)
+
+    prompt_template_default = """Your goal is to answer queries based on the provided images. Each image is one page from a single PDF document. Provide answers that are at least 3 sentences long. Clearly explain the reasoning behind your answer. Create trustworthy answers by referencing the material in the PDF pages. Do not reference page numbers unless they appear on the page images.
+
+---
+
+{query}"""
+
+    with st.expander("Prompt Template"):
+        prompt_template = st.text_area(
+            "Customize the prompt template",
+            prompt_template_default,
+            height=200,
+        )
+
+    top_k = st.slider(
+        "num pages to retrieve", min_value=1, max_value=8, value=3, step=1
+    )
     if st.button("answer query"):
         SS["query_embeddings"] = embed_query_texts(
             SS["colpali_model"],
@@ -171,7 +214,7 @@ with st.container(border=True):
         for ipage in range(len(SS["page_embeddings"])):
             # for every query token find the max_sim with every page patch
             patch_query_scores = np.dot(
-                SS['page_embeddings'][ipage],
+                SS["page_embeddings"][ipage],
                 SS["query_embeddings"].T,
             )
             max_sim_score = patch_query_scores.max(axis=0).sum()
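The loop above scores each page with the MaxSim rule described by the in-code comment: dot every page patch vector against every query token vector, keep the best-matching patch per query token, and sum over query tokens. A self-contained sketch of the same computation with dummy arrays (all shapes and values are placeholders for illustration):

    import numpy as np

    rng = np.random.default_rng(0)
    page_embeddings = rng.normal(size=(5, 1030, 128))  # (num_pages, num_patches, dim), dummy
    query_embeddings = rng.normal(size=(20, 128))      # (num_query_tokens, dim), dummy

    page_query_scores = []
    for ipage in range(page_embeddings.shape[0]):
        # (num_patches, num_query_tokens) similarity matrix for one page
        patch_query_scores = np.dot(page_embeddings[ipage], query_embeddings.T)
        # best patch per query token, summed over query tokens
        page_query_scores.append(patch_query_scores.max(axis=0).sum())

    i_ranked_pages = np.argsort(-np.array(page_query_scores))  # best pages first
    print(i_ranked_pages[:3])

Sorting the negated scores, as the app does, ranks pages from highest to lowest MaxSim score.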
@@ -181,25 +224,23 @@ with st.container(border=True):
         i_ranked_pages = np.argsort(-page_query_scores)
 
         page_images = []
-        for ii in range(top_k):
+        page_scores = []
+        num_pages = len(SS["page_images"])
+        for ii in range(min(top_k, num_pages)):
             page_images.append(SS["page_images"][i_ranked_pages[ii]])
+            page_scores.append(page_query_scores[i_ranked_pages[ii]])
         SS["retrieved_page_images"] = page_images
+        SS["retrieved_page_scores"] = page_scores
 
-
-        prompt = [
-            query +
-            " Think through your answer step by step. "
-            "Support your answer with descriptions of the images. "
-            "Do not infer information that is not in the images.",
-        ] + page_images
+        prompt = [prompt_template.format(query=query)] + page_images
 
         genai.configure(api_key=st.secrets["google_genai_api_key"])
-        # genai_model_name = "gemini-1.5-flash"
+        # genai_model_name = "gemini-1.5-flash"
         genai_model_name = "gemini-1.5-pro"
         gen_model = genai.GenerativeModel(
             model_name=genai_model_name,
             generation_config=genai.GenerationConfig(
-                temperature=0.1,
+                temperature=0.0,
             ),
         )
         response = gen_model.generate_content(prompt)
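The prompt passed to generate_content above is a plain Python list that mixes the formatted query text with the retrieved PIL page images; the google-generativeai client accepts such mixed lists for multimodal requests. A minimal standalone sketch, assuming a valid API key and placeholder image files (both hypothetical):

    import google.generativeai as genai
    from PIL import Image

    genai.configure(api_key="YOUR_API_KEY")  # placeholder, supply a real key
    gen_model = genai.GenerativeModel(
        model_name="gemini-1.5-pro",
        generation_config=genai.GenerationConfig(temperature=0.0),
    )

    # placeholder images standing in for the retrieved PDF pages
    retrieved_pages = [Image.open("page_1.png"), Image.open("page_2.png")]
    prompt = ["Summarize what these PDF pages say about evaluation."] + retrieved_pages

    response = gen_model.generate_content(prompt)
    print(response.text)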
@@ -208,12 +249,16 @@ with st.container(border=True):
 
 
 if SS["response"] is not None:
+    st.header("Response")
     st.write(SS["response"])
     st.header("Retrieved Pages")
-    display_pages(SS["retrieved_page_images"], "retrieved_pages")
-
+    display_pages(
+        SS["retrieved_page_images"],
+        "retrieved_pages",
+        captions=[f"Score={el:.2f}" for el in SS["retrieved_page_scores"]],
+    )
 
 
 if SS["page_images"] is not None:
-    st.header("All PDF Pages")
+    st.header("All Pages")
    display_pages(SS["page_images"], "all_pages")
 