Spaces:
Running
Running
Commit
·
bcf0900
1
Parent(s):
515ab32
more stuff
Browse files
app.py
CHANGED
@@ -23,8 +23,9 @@ def initialize_session_state():
|
|
23 |
keys = [
|
24 |
"colpali_model",
|
25 |
"page_images",
|
|
|
26 |
"retrieved_page_images",
|
27 |
-
"response",
|
28 |
]
|
29 |
for key in keys:
|
30 |
if key not in SS:
|
@@ -68,7 +69,7 @@ def load_colpali_model():
|
|
68 |
return model, processor
|
69 |
|
70 |
|
71 |
-
def embed_page_images(model, processor, page_images, batch_size=
|
72 |
dataloader = DataLoader(
|
73 |
page_images,
|
74 |
batch_size=batch_size,
|
@@ -76,11 +77,13 @@ def embed_page_images(model, processor, page_images, batch_size=2):
|
|
76 |
collate_fn=lambda x: process_images(processor, x),
|
77 |
)
|
78 |
page_embeddings = []
|
79 |
-
|
|
|
80 |
with torch.no_grad():
|
81 |
batch = {k: v.to(model.device) for k, v in batch.items()}
|
82 |
embeddings = model(**batch)
|
83 |
page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
|
|
|
84 |
return np.array(page_embeddings)
|
85 |
|
86 |
|
@@ -102,14 +105,15 @@ def embed_query_texts(model, processor, query_texts, batch_size=1):
|
|
102 |
return np.array(query_embeddings)[0]
|
103 |
|
104 |
|
105 |
-
|
106 |
def get_pdf_page_images_from_bytes(
|
107 |
pdf_bytes: bytes,
|
108 |
use_tmp_dir=False,
|
109 |
):
|
110 |
if use_tmp_dir:
|
111 |
with tempfile.TemporaryDirectory() as tmp_path:
|
112 |
-
page_images = pdf2image.convert_from_bytes(
|
|
|
|
|
113 |
else:
|
114 |
page_images = pdf2image.convert_from_bytes(pdf_bytes)
|
115 |
return page_images
|
@@ -125,13 +129,17 @@ def get_pdf_bytes_from_url(url: str) -> bytes | None:
|
|
125 |
return None
|
126 |
|
127 |
|
128 |
-
def display_pages(page_images, key):
|
129 |
n_cols = st.slider("ncol", min_value=1, max_value=8, value=4, step=1, key=key)
|
130 |
cols = st.columns(n_cols)
|
131 |
for ii_page, page_image in enumerate(page_images):
|
132 |
ii_col = ii_page % n_cols
|
133 |
with cols[ii_col]:
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
|
136 |
|
137 |
initialize_session_state()
|
@@ -142,24 +150,59 @@ if SS["colpali_model"] is None:
|
|
142 |
|
143 |
|
144 |
with st.sidebar:
|
145 |
-
url = st.text_input("arxiv url", "https://arxiv.org/pdf/2112.01488.pdf")
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
SS["
|
156 |
-
|
|
|
|
|
|
|
|
|
157 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
|
160 |
with st.container(border=True):
|
161 |
query = st.text_area("query")
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
if st.button("answer query"):
|
164 |
SS["query_embeddings"] = embed_query_texts(
|
165 |
SS["colpali_model"],
|
@@ -171,7 +214,7 @@ with st.container(border=True):
|
|
171 |
for ipage in range(len(SS["page_embeddings"])):
|
172 |
# for every query token find the max_sim with every page patch
|
173 |
patch_query_scores = np.dot(
|
174 |
-
SS[
|
175 |
SS["query_embeddings"].T,
|
176 |
)
|
177 |
max_sim_score = patch_query_scores.max(axis=0).sum()
|
@@ -181,25 +224,23 @@ with st.container(border=True):
|
|
181 |
i_ranked_pages = np.argsort(-page_query_scores)
|
182 |
|
183 |
page_images = []
|
184 |
-
|
|
|
|
|
185 |
page_images.append(SS["page_images"][i_ranked_pages[ii]])
|
|
|
186 |
SS["retrieved_page_images"] = page_images
|
|
|
187 |
|
188 |
-
|
189 |
-
prompt = [
|
190 |
-
query +
|
191 |
-
" Think through your answer step by step. "
|
192 |
-
"Support your answer with descriptions of the images. "
|
193 |
-
"Do not infer information that is not in the images.",
|
194 |
-
] + page_images
|
195 |
|
196 |
genai.configure(api_key=st.secrets["google_genai_api_key"])
|
197 |
-
# genai_model_name = "gemini-1.5-flash"
|
198 |
genai_model_name = "gemini-1.5-pro"
|
199 |
gen_model = genai.GenerativeModel(
|
200 |
model_name=genai_model_name,
|
201 |
generation_config=genai.GenerationConfig(
|
202 |
-
temperature=0.
|
203 |
),
|
204 |
)
|
205 |
response = gen_model.generate_content(prompt)
|
@@ -208,12 +249,16 @@ with st.container(border=True):
|
|
208 |
|
209 |
|
210 |
if SS["response"] is not None:
|
|
|
211 |
st.write(SS["response"])
|
212 |
st.header("Retrieved Pages")
|
213 |
-
display_pages(
|
214 |
-
|
|
|
|
|
|
|
215 |
|
216 |
|
217 |
if SS["page_images"] is not None:
|
218 |
-
st.header("All
|
219 |
display_pages(SS["page_images"], "all_pages")
|
|
|
23 |
keys = [
|
24 |
"colpali_model",
|
25 |
"page_images",
|
26 |
+
"page_embeddings",
|
27 |
"retrieved_page_images",
|
28 |
+
"retrieved_page_scores" "response",
|
29 |
]
|
30 |
for key in keys:
|
31 |
if key not in SS:
|
|
|
69 |
return model, processor
|
70 |
|
71 |
|
72 |
+
def embed_page_images(model, processor, page_images, batch_size=1):
|
73 |
dataloader = DataLoader(
|
74 |
page_images,
|
75 |
batch_size=batch_size,
|
|
|
77 |
collate_fn=lambda x: process_images(processor, x),
|
78 |
)
|
79 |
page_embeddings = []
|
80 |
+
pbar = st.progress(0, text="embedding pages")
|
81 |
+
for ibatch, batch in enumerate(dataloader):
|
82 |
with torch.no_grad():
|
83 |
batch = {k: v.to(model.device) for k, v in batch.items()}
|
84 |
embeddings = model(**batch)
|
85 |
page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
|
86 |
+
pbar.progress((ibatch + 1) / len(page_images), text="embedding pages")
|
87 |
return np.array(page_embeddings)
|
88 |
|
89 |
|
|
|
105 |
return np.array(query_embeddings)[0]
|
106 |
|
107 |
|
|
|
108 |
def get_pdf_page_images_from_bytes(
    pdf_bytes: bytes,
    use_tmp_dir=False,
):
    """Render every page of a PDF, given as raw bytes, to an image.

    Parameters
    ----------
    pdf_bytes : bytes
        Raw contents of a PDF file.
    use_tmp_dir : bool, optional
        When True, let pdf2image write its intermediate page files into a
        temporary directory that is removed on exit.
        NOTE(review): images produced with ``output_folder`` may be lazily
        backed by files in that directory — confirm they remain usable after
        the directory is deleted.

    Returns
    -------
    list
        One image per page of the PDF.
    """
    # Common case first: render entirely in memory.
    if not use_tmp_dir:
        return pdf2image.convert_from_bytes(pdf_bytes)
    # Otherwise stage page files in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmp_path:
        return pdf2image.convert_from_bytes(pdf_bytes, output_folder=tmp_path)
|
|
|
129 |
return None
|
130 |
|
131 |
|
132 |
+
def display_pages(page_images, key, captions=None):
    """Lay out page images in a user-adjustable column grid.

    A slider (keyed by ``key`` so several grids can coexist on one page)
    selects the number of columns; images are distributed round-robin.

    Parameters
    ----------
    page_images : sequence
        Images to display, one per page.
    key : str
        Unique widget key for the column-count slider.
    captions : sequence or None, optional
        Per-image captions parallel to ``page_images``; no captions when None.
    """
    n_cols = st.slider("ncol", min_value=1, max_value=8, value=4, step=1, key=key)
    cols = st.columns(n_cols)
    for idx, image in enumerate(page_images):
        # Round-robin assignment keeps column heights roughly balanced.
        with cols[idx % n_cols]:
            st.image(image, caption=None if captions is None else captions[idx])
|
143 |
|
144 |
|
145 |
initialize_session_state()
|
|
|
150 |
|
151 |
|
152 |
with st.sidebar:
    # --- PDF ingestion: fetch from a URL or accept a direct upload. ---
    with st.container(border=True):
        st.header("Load PDF (URL or Upload)")
        st.write("When a PDF is loaded, each page will be turned into an image.")

        url = st.text_input("Provide a URL", "https://arxiv.org/pdf/2404.15549v2")
        if st.button("load paper from url"):
            pdf_bytes = get_pdf_bytes_from_url(url)
            # get_pdf_bytes_from_url is typed `-> bytes | None`; guard the
            # None failure path instead of crashing inside pdf2image.
            if pdf_bytes is None:
                st.error("Could not download a PDF from the provided URL.")
            else:
                SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)

        uploaded_file = st.file_uploader("Upload a file", type=["pdf"])
        if uploaded_file is not None:
            pdf_bytes = uploaded_file.getvalue()
            SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)

    # --- Embedding: turn the page images into retrieval embeddings. ---
    with st.container(border=True):
        st.header("Embed Page Images")
        st.write(
            "In order to retrieve relevant images for a query, we must first embed the images."
        )
        if st.button("embed pages"):
            SS["page_embeddings"] = embed_page_images(
                SS["colpali_model"],
                SS["processor"],
                SS["page_images"],
            )

        # Status readouts so the user can see what is loaded/embedded.
        if SS["page_images"] is not None:
            st.write("Num Page Images: {}".format(len(SS["page_images"])))

        if SS["page_embeddings"] is not None:
            st.write("Page Embeddings Shape: {}".format(SS["page_embeddings"].shape))
|
185 |
|
186 |
|
187 |
with st.container(border=True):
|
188 |
query = st.text_area("query")
|
189 |
+
|
190 |
+
prompt_template_default = """Your goal is to answer queries based on the provided images. Each image is one page from a single PDF document. Provide answers that are at least 3 sentences long. Clearly explain the reasoning behind your answer. Create trustworthy answers by referencing the material in the PDF pages. Do not reference page numbers unless they appear on the page images.
|
191 |
+
|
192 |
+
---
|
193 |
+
|
194 |
+
{query}"""
|
195 |
+
|
196 |
+
with st.expander("Prompt Template"):
|
197 |
+
prompt_template = st.text_area(
|
198 |
+
"Customize the prompt template",
|
199 |
+
prompt_template_default,
|
200 |
+
height=200,
|
201 |
+
)
|
202 |
+
|
203 |
+
top_k = st.slider(
|
204 |
+
"num pages to retrieve", min_value=1, max_value=8, value=3, step=1
|
205 |
+
)
|
206 |
if st.button("answer query"):
|
207 |
SS["query_embeddings"] = embed_query_texts(
|
208 |
SS["colpali_model"],
|
|
|
214 |
for ipage in range(len(SS["page_embeddings"])):
|
215 |
# for every query token find the max_sim with every page patch
|
216 |
patch_query_scores = np.dot(
|
217 |
+
SS["page_embeddings"][ipage],
|
218 |
SS["query_embeddings"].T,
|
219 |
)
|
220 |
max_sim_score = patch_query_scores.max(axis=0).sum()
|
|
|
224 |
i_ranked_pages = np.argsort(-page_query_scores)
|
225 |
|
226 |
page_images = []
|
227 |
+
page_scores = []
|
228 |
+
num_pages = len(SS["page_images"])
|
229 |
+
for ii in range(min(top_k, num_pages)):
|
230 |
page_images.append(SS["page_images"][i_ranked_pages[ii]])
|
231 |
+
page_scores.append(page_query_scores[i_ranked_pages[ii]])
|
232 |
SS["retrieved_page_images"] = page_images
|
233 |
+
SS["retrieved_page_scores"] = page_scores
|
234 |
|
235 |
+
prompt = [prompt_template.format(query=query)] + page_images
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
genai.configure(api_key=st.secrets["google_genai_api_key"])
|
238 |
+
# genai_model_name = "gemini-1.5-flash"
|
239 |
genai_model_name = "gemini-1.5-pro"
|
240 |
gen_model = genai.GenerativeModel(
|
241 |
model_name=genai_model_name,
|
242 |
generation_config=genai.GenerationConfig(
|
243 |
+
temperature=0.0,
|
244 |
),
|
245 |
)
|
246 |
response = gen_model.generate_content(prompt)
|
|
|
249 |
|
250 |
|
251 |
# Show the generated answer (when one exists) followed by the pages that
# were retrieved to support it, each captioned with its retrieval score.
if SS["response"] is not None:
    st.header("Response")
    st.write(SS["response"])
    st.header("Retrieved Pages")
    score_captions = [f"Score={el:.2f}" for el in SS["retrieved_page_scores"]]
    display_pages(
        SS["retrieved_page_images"],
        "retrieved_pages",
        captions=score_captions,
    )


# A grid of every page of the loaded PDF, independent of any query.
if SS["page_images"] is not None:
    st.header("All Pages")
    display_pages(SS["page_images"], "all_pages")
|