Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,11 @@ import configparser
|
|
7 |
from datetime import datetime
|
8 |
from torch import cuda
|
9 |
|
10 |
-
#
|
|
|
|
|
|
|
|
|
11 |
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
|
12 |
from appStore.prep_utils import create_documents, get_client
|
13 |
from appStore.embed import hybrid_embed_chunks
|
@@ -19,10 +23,9 @@ from appStore.region_utils import (
|
|
19 |
get_regions,
|
20 |
get_country_name_and_region_mapping
|
21 |
)
|
22 |
-
# TF-IDF
|
23 |
# from appStore.tfidf_extraction import extract_top_keywords
|
24 |
|
25 |
-
# Import helper modules, including format_project_id for formatting IDs
|
26 |
from appStore.rag_utils import (
|
27 |
highlight_query,
|
28 |
get_rag_answer,
|
@@ -34,52 +37,57 @@ from appStore.filter_utils import (
|
|
34 |
filter_results,
|
35 |
get_crs_options
|
36 |
)
|
37 |
-
|
38 |
from appStore.crs_utils import lookup_crs_value
|
39 |
|
40 |
-
|
41 |
-
# Model
|
42 |
-
|
|
|
43 |
config = configparser.ConfigParser()
|
44 |
config.read('model_params.cfg')
|
45 |
-
|
46 |
DEDICATED_MODEL = config.get('MODEL', 'DEDICATED_MODEL')
|
47 |
DEDICATED_ENDPOINT = config.get('MODEL', 'DEDICATED_ENDPOINT')
|
48 |
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
|
49 |
|
|
|
50 |
st.set_page_config(page_title="SEARCH IATI", layout='wide')
|
51 |
|
52 |
-
|
53 |
-
# Cache
|
54 |
-
|
55 |
@st.cache_data
def load_project_data():
    """
    Fetch the processed GIZ worldwide project data.

    Decorated with ``st.cache_data`` so the processing step is not
    repeated on every Streamlit rerun.

    Returns:
        Whatever ``process_giz_worldwide()`` produces (presumably a pandas
        DataFrame, since callers index it by column name - confirm).
    """
    processed = process_giz_worldwide()
    return processed
|
61 |
|
62 |
project_data = load_project_data()
|
63 |
|
64 |
-
#
|
|
|
|
|
65 |
budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
|
66 |
min_budget_val = float(budget_series.min() / 1e6)
|
67 |
max_budget_val = float(budget_series.max() / 1e6)
|
68 |
|
69 |
-
|
70 |
-
# Prepare
|
71 |
-
|
72 |
region_lookup_path = "docStore/regions_lookup.csv"
|
73 |
region_df = load_region_data(region_lookup_path)
|
74 |
|
75 |
-
|
76 |
-
#
|
77 |
-
|
78 |
device = 'cuda' if cuda.is_available() else 'cpu'
|
79 |
|
80 |
-
|
81 |
-
#
|
82 |
-
|
83 |
col_title, col_about = st.columns([8, 2])
|
84 |
with col_title:
|
85 |
st.markdown("<h1 style='text-align:center;'>GIZ Project Search (PROTOTYPE)</h1>", unsafe_allow_html=True)
|
@@ -87,30 +95,32 @@ with col_about:
|
|
87 |
with st.expander("ℹ️ About"):
|
88 |
st.markdown(
|
89 |
"""
|
90 |
-
This
|
91 |
-
|
92 |
**Please do NOT enter sensitive or personal information.**
|
93 |
-
**Note**: The answers are AI-generated and may be
|
94 |
""", unsafe_allow_html=True
|
95 |
)
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
#
|
100 |
-
###########################################
|
101 |
collection_name = "giz_worldwide"
|
102 |
client = get_client()
|
|
|
|
|
103 |
print(client.get_collections())
|
104 |
|
105 |
-
# Uncomment if
|
106 |
# chunks = process_giz_worldwide()
|
107 |
# temp_doc = create_documents(chunks, 'chunks')
|
108 |
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
|
109 |
|
|
|
110 |
max_end_year = get_max_end_year(client, collection_name)
|
111 |
_, unique_sub_regions = get_regions(region_df)
|
112 |
|
113 |
-
# Build country
|
114 |
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
|
115 |
client,
|
116 |
collection_name,
|
@@ -121,10 +131,13 @@ country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mappi
|
|
121 |
)
|
122 |
unique_country_names = sorted(country_name_mapping.keys())
|
123 |
|
124 |
-
|
125 |
-
#
|
126 |
-
|
127 |
def reset_filters():
|
|
|
|
|
|
|
128 |
st.session_state["region_filter"] = "All/Not allocated"
|
129 |
st.session_state["country_filter"] = "All/Not allocated"
|
130 |
current_year = datetime.now().year
|
@@ -138,19 +151,24 @@ def reset_filters():
|
|
138 |
st.session_state["page"] = 1
|
139 |
|
140 |
def reset_page():
    """
    Send pagination back to the first page.

    Used as an ``on_change`` callback by the filter widgets.
    """
    st.session_state["page"] = 1
|
142 |
|
143 |
-
|
144 |
-
# Main
|
145 |
-
|
146 |
-
var = st.text_input("Enter Question", key="query",
|
147 |
-
###########################################
|
148 |
-
# Filter Controls - Row 1
|
149 |
-
###########################################
|
150 |
-
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
|
151 |
|
|
|
|
|
|
|
|
|
152 |
with col1:
|
153 |
-
region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions),
|
|
|
|
|
154 |
if region_filter == "All/Not allocated":
|
155 |
filtered_country_names = unique_country_names
|
156 |
else:
|
@@ -160,7 +178,8 @@ else:
|
|
160 |
]
|
161 |
|
162 |
with col2:
|
163 |
-
country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names,
|
|
|
164 |
|
165 |
with col3:
|
166 |
current_year = datetime.now().year
|
@@ -188,14 +207,14 @@ with col5:
|
|
188 |
on_change=reset_page
|
189 |
)
|
190 |
|
191 |
-
|
192 |
-
# Filter Controls - Row 2
|
193 |
-
|
194 |
col1_2, col2_2, col3_2, col4_2, col5_2 = st.columns(5)
|
195 |
-
|
196 |
with col1_2:
|
197 |
client_options = sorted(project_data["client"].dropna().unique().tolist())
|
198 |
client_filter = st.selectbox("Client", ["All/Not allocated"] + client_options, key="client_filter")
|
|
|
199 |
with col2_2:
|
200 |
st.empty()
|
201 |
with col3_2:
|
@@ -205,51 +224,59 @@ with col4_2:
|
|
205 |
with col5_2:
|
206 |
st.empty()
|
207 |
|
208 |
-
|
209 |
-
|
210 |
-
#
|
211 |
-
###########################################
|
212 |
col_left, col_right = st.columns([11, 1])
|
213 |
with col_left:
|
214 |
-
#
|
215 |
show_exact_matches = st.checkbox("Show only exact matches", key="show_exact_matches", on_change=reset_page)
|
216 |
with col_right:
|
217 |
-
#
|
218 |
with st.container():
|
219 |
st.markdown("<div style='text-align: right;'>", unsafe_allow_html=True)
|
220 |
if st.button("**Reset Filters**", key="reset_button_row3"):
|
221 |
reset_filters()
|
222 |
st.markdown("</div>", unsafe_allow_html=True)
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
#
|
227 |
-
###########################################
|
228 |
def valid_project_id(pid_str):
    """
    Check whether a project ID value is usable.

    Args:
        pid_str: The project ID, normally a string. Non-string values
            (for example a float NaN) are converted with ``str()`` before
            the placeholder check instead of raising.

    Returns:
        bool: False for falsy values, whitespace-only strings and the
        placeholders "nan" / "none" (case-insensitive, surrounding
        whitespace ignored); True otherwise.
    """
    if not pid_str:
        return False
    # Coerce before calling string methods so a float NaN cannot raise
    # AttributeError; strip so padded placeholders are still rejected.
    text = str(pid_str).strip()
    return bool(text) and text.lower() not in ("nan", "none")
|
234 |
|
235 |
-
|
236 |
-
# Main Search
|
237 |
-
|
238 |
if not var.strip():
|
|
|
239 |
st.info("Please enter a question to see results.")
|
240 |
else:
|
241 |
-
# 1
|
242 |
results = hybrid_search(client, var, collection_name, limit=500)
|
243 |
semantic_all, lexical_all = results[0], results[1]
|
244 |
|
245 |
-
# Filter out short
|
246 |
semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
|
247 |
lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
|
248 |
|
249 |
-
# Apply threshold to semantic
|
250 |
semantic_thresholded = [r for r in semantic_all if r.score >= 0.25]
|
251 |
|
252 |
-
# 2
|
253 |
filtered_semantic = filter_results(
|
254 |
semantic_thresholded,
|
255 |
country_filter,
|
@@ -275,29 +302,38 @@ else:
|
|
275 |
get_country_name
|
276 |
)
|
277 |
|
278 |
-
# Additional
|
279 |
if client_filter != "All/Not allocated":
|
280 |
filtered_semantic = [r for r in filtered_semantic if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter]
|
281 |
filtered_lexical = [r for r in filtered_lexical if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter]
|
282 |
|
283 |
-
# Remove
|
284 |
filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
|
285 |
filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
|
286 |
|
287 |
def format_currency(value):
    """
    Format a numeric value as a whole-euro currency string.

    Args:
        value: A number or numeric string (e.g. ``1234567.8`` or ``"2500"``).

    Returns:
        The amount truncated to an integer, with thousands separators and a
        leading euro sign (e.g. ``"€1,234,567"``). If the value cannot be
        converted (non-numeric text, ``None``, NaN or infinity), the
        original value is returned unchanged.
    """
    try:
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        # float() raises ValueError/TypeError for non-numeric input;
        # int() raises ValueError for NaN and OverflowError for infinity.
        return value
|
292 |
|
293 |
-
# --- Reprint
|
294 |
from html import escape as escape_html

# Escape the user's query before embedding it in raw HTML: with
# unsafe_allow_html=True, markup typed into the search box would
# otherwise be rendered as-is.
st.markdown(
    f"<div style='text-align: left; font-size:2.1em; font-style: italic; font-weight: bold;'>Query: {escape_html(var)}</div>",
    unsafe_allow_html=True
)
|
298 |
|
299 |
-
# 3
|
300 |
-
#
|
301 |
if show_exact_matches:
|
302 |
query_substring = var.strip().lower()
|
303 |
lexical_substring_filtered = [
|
@@ -308,21 +344,19 @@ else:
|
|
308 |
if not filtered_lexical_no_dupe:
|
309 |
st.write('No exact matches, consider unchecking "Show only exact matches"')
|
310 |
else:
|
311 |
-
top_results = filtered_lexical_no_dupe #
|
312 |
-
|
313 |
-
# --- Pagination
|
314 |
page_size = 15
|
315 |
total_results = len(top_results)
|
316 |
total_pages = (total_results - 1) // page_size + 1
|
317 |
if "page" not in st.session_state:
|
318 |
st.session_state.page = 1
|
319 |
current_page = st.session_state.page
|
320 |
-
|
321 |
-
#
|
322 |
page_num = f"<b style='color: green;'>{current_page}</b>" if current_page != 1 else f"<b>{current_page}</b>"
|
323 |
total_pages_str = f"<b>{total_pages}</b>"
|
324 |
-
|
325 |
-
# Create two columns: one for the title and one for the selectbox
|
326 |
col_title, col_pag = st.columns([13, 1])
|
327 |
with col_title:
|
328 |
st.markdown(
|
@@ -333,20 +367,22 @@ else:
|
|
333 |
new_page_top = st.selectbox("Select Page", list(range(1, total_pages + 1)),
|
334 |
index=current_page - 1, key="page_top")
|
335 |
st.session_state.page = new_page_top
|
336 |
-
|
337 |
start_index = (st.session_state.page - 1) * page_size
|
338 |
end_index = start_index + page_size
|
339 |
paged_results = top_results[start_index:end_index]
|
340 |
|
|
|
341 |
for i, res in enumerate(paged_results, start=start_index+1):
|
342 |
metadata = res.payload.get('metadata', {})
|
343 |
if "title" not in metadata:
|
344 |
metadata["title"] = compute_title(metadata)
|
|
|
345 |
title_html = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
|
346 |
title_clean = re.sub(r'<a.*?>|</a>', '', title_html)
|
347 |
-
# Prepend the result number
|
348 |
st.markdown(f"#### {i}. **{title_clean}**", unsafe_allow_html=True)
|
349 |
|
|
|
350 |
objective = metadata.get("objective", "None")
|
351 |
desc_en = metadata.get("description.en", "").strip()
|
352 |
desc_de = metadata.get("description.de", "").strip()
|
@@ -377,7 +413,8 @@ else:
|
|
377 |
new_crs_value = lookup_crs_value(crs_key_clean)
|
378 |
new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
|
379 |
crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
|
380 |
-
|
|
|
381 |
predecessor = metadata.get("predecessor_id", "").strip()
|
382 |
successor = metadata.get("successor_id", "").strip()
|
383 |
parts = []
|
@@ -394,8 +431,8 @@ else:
|
|
394 |
formatted_succ = successor
|
395 |
parts.append(f"**Successor Project:** {formatted_succ}")
|
396 |
extra_line = " | ".join(parts) if parts else ""
|
397 |
-
|
398 |
-
# Build additional
|
399 |
additional_text = (
|
400 |
f"**Objective:** {highlight_query(objective, var)}<br>"
|
401 |
f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
|
@@ -406,15 +443,17 @@ else:
|
|
406 |
additional_text += f"<br>{extra_line}"
|
407 |
additional_text += f"<br>**Country:** {country_raw}<br>**Sector:** {crs_combined}"
|
408 |
|
|
|
409 |
contact = metadata.get("contact", "").strip()
|
410 |
if contact and contact.lower() != "[email protected]":
|
411 |
additional_text += f"<br>**Contact:** [email protected]"
|
412 |
st.markdown(additional_text, unsafe_allow_html=True)
|
413 |
st.divider()
|
414 |
|
415 |
-
# Bottom pagination widget
|
416 |
col_pag_bot = st.columns([11, 1])[1]
|
417 |
-
new_page_bot = col_pag_bot.selectbox("Select Page", list(range(1, total_pages + 1)),
|
|
|
418 |
st.session_state.page = new_page_bot
|
419 |
|
420 |
# Semantic Search Results Branch
|
@@ -433,17 +472,15 @@ else:
|
|
433 |
start_index = (st.session_state.page - 1) * page_size
|
434 |
end_index = start_index + page_size
|
435 |
top_results = filtered_semantic_no_dupe[start_index:end_index]
|
436 |
-
|
|
|
437 |
rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
|
438 |
bullet_lines = []
|
439 |
for line in rag_answer.splitlines():
|
440 |
if line.strip():
|
441 |
-
#
|
442 |
line = re.sub(r'^[-*]\s+', '', line.strip())
|
443 |
-
# Convert markdown bold (e.g. **Title [2018.2101.6]**) to HTML <b> tags
|
444 |
line = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', line)
|
445 |
-
# Optionally, bold any standalone numbers (if desired)
|
446 |
-
# line = re.sub(r'(\d+)', r'<b>\1</b>', line)
|
447 |
bullet_lines.append(f"<li>{line}</li>")
|
448 |
formatted_rag_answer = (
|
449 |
"<div style='background-color: #f0f0f0; padding: 10px;'>"
|
@@ -454,10 +491,9 @@ else:
|
|
454 |
st.markdown(formatted_rag_answer, unsafe_allow_html=True)
|
455 |
|
456 |
st.divider()
|
457 |
-
#
|
458 |
col_title, col_pag = st.columns([13, 1])
|
459 |
with col_title:
|
460 |
-
# Format page numbers with bold formatting (green if not the first page)
|
461 |
page_num = f"<b style='color: green;'>{current_page}</b>" if current_page != 1 else f"<b>{current_page}</b>"
|
462 |
total_pages_str = f"<b>{total_pages}</b>"
|
463 |
st.markdown(
|
@@ -465,9 +501,11 @@ else:
|
|
465 |
unsafe_allow_html=True
|
466 |
)
|
467 |
with col_pag:
|
468 |
-
new_page_top = st.selectbox("Select Page", list(range(1, total_pages + 1)),
|
|
|
469 |
st.session_state.page = new_page_top
|
470 |
|
|
|
471 |
for i, res in enumerate(top_results, start=start_index+1):
|
472 |
metadata = res.payload.get('metadata', {})
|
473 |
if "title" not in metadata:
|
@@ -539,7 +577,8 @@ else:
|
|
539 |
st.markdown(additional_text, unsafe_allow_html=True)
|
540 |
st.divider()
|
541 |
|
542 |
-
# Bottom pagination widget
|
543 |
col_pag_bot = st.columns([13, 1])[1]
|
544 |
-
new_page_bot = col_pag_bot.selectbox("Select Page", list(range(1, total_pages + 1)),
|
545 |
-
|
|
|
|
7 |
from datetime import datetime
|
8 |
from torch import cuda
|
9 |
|
10 |
+
# ------------------------------------------------------------------------------
|
11 |
+
# Import modules from the appStore package
|
12 |
+
# These modules handle data preparation, embedding, search, region handling,
|
13 |
+
# retrieval of RAG answers, and filtering utilities.
|
14 |
+
# ------------------------------------------------------------------------------
|
15 |
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
|
16 |
from appStore.prep_utils import create_documents, get_client
|
17 |
from appStore.embed import hybrid_embed_chunks
|
|
|
23 |
get_regions,
|
24 |
get_country_name_and_region_mapping
|
25 |
)
|
26 |
+
# Note: The TF-IDF extraction is currently not used in the app.
|
27 |
# from appStore.tfidf_extraction import extract_top_keywords
|
28 |
|
|
|
29 |
from appStore.rag_utils import (
|
30 |
highlight_query,
|
31 |
get_rag_answer,
|
|
|
37 |
filter_results,
|
38 |
get_crs_options
|
39 |
)
|
|
|
40 |
from appStore.crs_utils import lookup_crs_value
|
41 |
|
42 |
+
# ------------------------------------------------------------------------------
|
43 |
+
# Model Configuration
|
44 |
+
# ------------------------------------------------------------------------------
|
45 |
+
# Read model parameters from configuration file
|
46 |
config = configparser.ConfigParser()
|
47 |
config.read('model_params.cfg')
|
|
|
48 |
DEDICATED_MODEL = config.get('MODEL', 'DEDICATED_MODEL')
|
49 |
DEDICATED_ENDPOINT = config.get('MODEL', 'DEDICATED_ENDPOINT')
|
50 |
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
|
51 |
|
52 |
+
# Set page configuration for Streamlit
|
53 |
st.set_page_config(page_title="SEARCH IATI", layout='wide')
|
54 |
|
55 |
+
# ------------------------------------------------------------------------------
|
56 |
+
# Load and Cache Project Data
|
57 |
+
# ------------------------------------------------------------------------------
|
58 |
@st.cache_data
def load_project_data():
    """
    Fetch the processed GIZ worldwide project data.

    Decorated with ``st.cache_data`` so the processing step is not
    repeated on every Streamlit rerun.

    Returns:
        Whatever ``process_giz_worldwide()`` produces (presumably a pandas
        DataFrame, since callers index it by column name - confirm).
    """
    processed = process_giz_worldwide()
    return processed
|
67 |
|
68 |
project_data = load_project_data()
|
69 |
|
70 |
+
# ------------------------------------------------------------------------------
|
71 |
+
# Calculate Budget Range (in million euros)
|
72 |
+
# ------------------------------------------------------------------------------
|
73 |
budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
|
74 |
min_budget_val = float(budget_series.min() / 1e6)
|
75 |
max_budget_val = float(budget_series.max() / 1e6)
|
76 |
|
77 |
+
# ------------------------------------------------------------------------------
|
78 |
+
# Prepare Region Data
|
79 |
+
# ------------------------------------------------------------------------------
|
80 |
region_lookup_path = "docStore/regions_lookup.csv"
|
81 |
region_df = load_region_data(region_lookup_path)
|
82 |
|
83 |
+
# ------------------------------------------------------------------------------
|
84 |
+
# Determine Device for Computation
|
85 |
+
# ------------------------------------------------------------------------------
|
86 |
device = 'cuda' if cuda.is_available() else 'cpu'
|
87 |
|
88 |
+
# ------------------------------------------------------------------------------
|
89 |
+
# Layout: Header and About Section
|
90 |
+
# ------------------------------------------------------------------------------
|
91 |
col_title, col_about = st.columns([8, 2])
|
92 |
with col_title:
|
93 |
st.markdown("<h1 style='text-align:center;'>GIZ Project Search (PROTOTYPE)</h1>", unsafe_allow_html=True)
|
|
|
95 |
with st.expander("ℹ️ About"):
|
96 |
st.markdown(
|
97 |
"""
|
98 |
+
This prototype app uses publicly available project data from the German
|
99 |
+
International Cooperation Society (GIZ) as of 23rd February 2025.
|
100 |
**Please do NOT enter sensitive or personal information.**
|
101 |
+
**Note**: The answers are AI-generated and may be incorrect or misleading.
|
102 |
""", unsafe_allow_html=True
|
103 |
)
|
104 |
|
105 |
+
# ------------------------------------------------------------------------------
|
106 |
+
# Create or Load the Embeddings Collection
|
107 |
+
# ------------------------------------------------------------------------------
|
|
|
108 |
collection_name = "giz_worldwide"
|
109 |
client = get_client()
|
110 |
+
|
111 |
+
# Print existing collections to stdout (server log) for debugging purposes
|
112 |
print(client.get_collections())
|
113 |
|
114 |
+
# Uncomment the block below if you need to reprocess and embed documents.
|
115 |
# chunks = process_giz_worldwide()
|
116 |
# temp_doc = create_documents(chunks, 'chunks')
|
117 |
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
|
118 |
|
119 |
+
# Retrieve maximum project end year and region mapping
|
120 |
max_end_year = get_max_end_year(client, collection_name)
|
121 |
_, unique_sub_regions = get_regions(region_df)
|
122 |
|
123 |
+
# Build mapping between country names and region codes
|
124 |
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
|
125 |
client,
|
126 |
collection_name,
|
|
|
131 |
)
|
132 |
unique_country_names = sorted(country_name_mapping.keys())
|
133 |
|
134 |
+
# ------------------------------------------------------------------------------
|
135 |
+
# Session State Reset Functions
|
136 |
+
# ------------------------------------------------------------------------------
|
137 |
def reset_filters():
|
138 |
+
"""
|
139 |
+
Reset all filter options in the session state to their default values.
|
140 |
+
"""
|
141 |
st.session_state["region_filter"] = "All/Not allocated"
|
142 |
st.session_state["country_filter"] = "All/Not allocated"
|
143 |
current_year = datetime.now().year
|
|
|
151 |
st.session_state["page"] = 1
|
152 |
|
153 |
def reset_page():
    """
    Send pagination back to the first page.

    Used as an ``on_change`` callback by the query and filter widgets.
    """
    st.session_state["page"] = 1
|
158 |
|
159 |
+
# ------------------------------------------------------------------------------
|
160 |
+
# Main Query Input
|
161 |
+
# ------------------------------------------------------------------------------
|
162 |
+
var = st.text_input("Enter Question", key="query", on_change=reset_page)
|
|
|
|
|
|
|
|
|
163 |
|
164 |
+
# ------------------------------------------------------------------------------
|
165 |
+
# Filter Controls - Row 1: Basic Filters
|
166 |
+
# ------------------------------------------------------------------------------
|
167 |
+
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
|
168 |
with col1:
|
169 |
+
region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions),
|
170 |
+
key="region_filter", on_change=reset_page)
|
171 |
+
# If a specific region is selected, filter the country names accordingly.
|
172 |
if region_filter == "All/Not allocated":
|
173 |
filtered_country_names = unique_country_names
|
174 |
else:
|
|
|
178 |
]
|
179 |
|
180 |
with col2:
|
181 |
+
country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names,
|
182 |
+
key="country_filter", on_change=reset_page)
|
183 |
|
184 |
with col3:
|
185 |
current_year = datetime.now().year
|
|
|
207 |
on_change=reset_page
|
208 |
)
|
209 |
|
210 |
+
# ------------------------------------------------------------------------------
|
211 |
+
# Filter Controls - Row 2: Additional Filters
|
212 |
+
# ------------------------------------------------------------------------------
|
213 |
col1_2, col2_2, col3_2, col4_2, col5_2 = st.columns(5)
|
|
|
214 |
with col1_2:
|
215 |
client_options = sorted(project_data["client"].dropna().unique().tolist())
|
216 |
client_filter = st.selectbox("Client", ["All/Not allocated"] + client_options, key="client_filter")
|
217 |
+
# Columns 2 to 5 are left empty for layout alignment
|
218 |
with col2_2:
|
219 |
st.empty()
|
220 |
with col3_2:
|
|
|
224 |
with col5_2:
|
225 |
st.empty()
|
226 |
|
227 |
+
# ------------------------------------------------------------------------------
|
228 |
+
# Filter Controls - Row 3: Toggle and Reset Button
|
229 |
+
# ------------------------------------------------------------------------------
|
|
|
230 |
col_left, col_right = st.columns([11, 1])
|
231 |
with col_left:
|
232 |
+
# Checkbox to toggle exact match filtering
|
233 |
show_exact_matches = st.checkbox("Show only exact matches", key="show_exact_matches", on_change=reset_page)
|
234 |
with col_right:
|
235 |
+
# Reset filters button (right-aligned)
|
236 |
with st.container():
|
237 |
st.markdown("<div style='text-align: right;'>", unsafe_allow_html=True)
|
238 |
if st.button("**Reset Filters**", key="reset_button_row3"):
|
239 |
reset_filters()
|
240 |
st.markdown("</div>", unsafe_allow_html=True)
|
241 |
|
242 |
+
# ------------------------------------------------------------------------------
|
243 |
+
# Helper Function: Validate Project ID
|
244 |
+
# ------------------------------------------------------------------------------
|
|
|
245 |
def valid_project_id(pid_str):
    """
    Check whether a project ID value is usable.

    Args:
        pid_str: The project ID, normally a string. Non-string values
            (for example a float NaN) are converted with ``str()`` before
            the placeholder check instead of raising.

    Returns:
        bool: False for falsy values, whitespace-only strings and the
        placeholders "nan" / "none" (case-insensitive, surrounding
        whitespace ignored); True otherwise.
    """
    if not pid_str:
        return False
    # Coerce before calling string methods so a float NaN cannot raise
    # AttributeError; strip so padded placeholders are still rejected.
    text = str(pid_str).strip()
    return bool(text) and text.lower() not in ("nan", "none")
|
260 |
|
261 |
+
# ------------------------------------------------------------------------------
|
262 |
+
# Main Search and Display Logic
|
263 |
+
# ------------------------------------------------------------------------------
|
264 |
if not var.strip():
|
265 |
+
# Inform the user if no query is entered.
|
266 |
st.info("Please enter a question to see results.")
|
267 |
else:
|
268 |
+
# --- 1. Execute Hybrid Search ---
|
269 |
results = hybrid_search(client, var, collection_name, limit=500)
|
270 |
semantic_all, lexical_all = results[0], results[1]
|
271 |
|
272 |
+
# Filter out results with very short page content
|
273 |
semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
|
274 |
lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
|
275 |
|
276 |
+
# Keep only semantic results whose score is at least 0.25
|
277 |
semantic_thresholded = [r for r in semantic_all if r.score >= 0.25]
|
278 |
|
279 |
+
# --- 2. Apply User-Selected Filters ---
|
280 |
filtered_semantic = filter_results(
|
281 |
semantic_thresholded,
|
282 |
country_filter,
|
|
|
302 |
get_country_name
|
303 |
)
|
304 |
|
305 |
+
# Additional filtering by client if selected
|
306 |
if client_filter != "All/Not allocated":
|
307 |
filtered_semantic = [r for r in filtered_semantic if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter]
|
308 |
filtered_lexical = [r for r in filtered_lexical if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter]
|
309 |
|
310 |
+
# Remove duplicate entries from the results
|
311 |
filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
|
312 |
filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
|
313 |
|
314 |
def format_currency(value):
    """
    Format a numeric value as a whole-euro currency string.

    Args:
        value: A number or numeric string (e.g. ``1234567.8`` or ``"2500"``).

    Returns:
        The amount truncated to an integer, with thousands separators and a
        leading euro sign (e.g. ``"€1,234,567"``). If the value cannot be
        converted (non-numeric text, ``None``, NaN or infinity), the
        original value is returned unchanged.
    """
    try:
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        # float() raises ValueError/TypeError for non-numeric input;
        # int() raises ValueError for NaN and OverflowError for infinity.
        return value
|
328 |
|
329 |
+
# --- Reprint the user query for clarity ---
|
330 |
from html import escape as escape_html

# Escape the user's query before embedding it in raw HTML: with
# unsafe_allow_html=True, markup typed into the search box would
# otherwise be rendered as-is.
st.markdown(
    f"<div style='text-align: left; font-size:2.1em; font-style: italic; font-weight: bold;'>Query: {escape_html(var)}</div>",
    unsafe_allow_html=True
)
|
334 |
|
335 |
+
# --- 3. Display Search Results Based on Matching Mode ---
|
336 |
+
# Lexical (Exact Match) Search Results Branch
|
337 |
if show_exact_matches:
|
338 |
query_substring = var.strip().lower()
|
339 |
lexical_substring_filtered = [
|
|
|
344 |
if not filtered_lexical_no_dupe:
|
345 |
st.write('No exact matches, consider unchecking "Show only exact matches"')
|
346 |
else:
|
347 |
+
top_results = filtered_lexical_no_dupe # Use all matching lexical results
|
348 |
+
|
349 |
+
# --- Pagination Setup ---
|
350 |
page_size = 15
|
351 |
total_results = len(top_results)
|
352 |
total_pages = (total_results - 1) // page_size + 1
|
353 |
if "page" not in st.session_state:
|
354 |
st.session_state.page = 1
|
355 |
current_page = st.session_state.page
|
356 |
+
|
357 |
+
# Display current page info
|
358 |
page_num = f"<b style='color: green;'>{current_page}</b>" if current_page != 1 else f"<b>{current_page}</b>"
|
359 |
total_pages_str = f"<b>{total_pages}</b>"
|
|
|
|
|
360 |
col_title, col_pag = st.columns([13, 1])
|
361 |
with col_title:
|
362 |
st.markdown(
|
|
|
367 |
new_page_top = st.selectbox("Select Page", list(range(1, total_pages + 1)),
|
368 |
index=current_page - 1, key="page_top")
|
369 |
st.session_state.page = new_page_top
|
370 |
+
|
371 |
start_index = (st.session_state.page - 1) * page_size
|
372 |
end_index = start_index + page_size
|
373 |
paged_results = top_results[start_index:end_index]
|
374 |
|
375 |
+
# Display each result with formatted metadata and content preview
|
376 |
for i, res in enumerate(paged_results, start=start_index+1):
|
377 |
metadata = res.payload.get('metadata', {})
|
378 |
if "title" not in metadata:
|
379 |
metadata["title"] = compute_title(metadata)
|
380 |
+
# Highlight query text in the title
|
381 |
title_html = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
|
382 |
title_clean = re.sub(r'<a.*?>|</a>', '', title_html)
|
|
|
383 |
st.markdown(f"#### {i}. **{title_clean}**", unsafe_allow_html=True)
|
384 |
|
385 |
+
# Prepare a description preview with an expandable "Show more" option
|
386 |
objective = metadata.get("objective", "None")
|
387 |
desc_en = metadata.get("description.en", "").strip()
|
388 |
desc_de = metadata.get("description.de", "").strip()
|
|
|
413 |
new_crs_value = lookup_crs_value(crs_key_clean)
|
414 |
new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
|
415 |
crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
|
416 |
+
|
417 |
+
# Process predecessor and successor project IDs if available
|
418 |
predecessor = metadata.get("predecessor_id", "").strip()
|
419 |
successor = metadata.get("successor_id", "").strip()
|
420 |
parts = []
|
|
|
431 |
formatted_succ = successor
|
432 |
parts.append(f"**Successor Project:** {formatted_succ}")
|
433 |
extra_line = " | ".join(parts) if parts else ""
|
434 |
+
|
435 |
+
# Build additional project information text
|
436 |
additional_text = (
|
437 |
f"**Objective:** {highlight_query(objective, var)}<br>"
|
438 |
f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
|
|
|
443 |
additional_text += f"<br>{extra_line}"
|
444 |
additional_text += f"<br>**Country:** {country_raw}<br>**Sector:** {crs_combined}"
|
445 |
|
446 |
+
# Hide sensitive contact info if present
|
447 |
contact = metadata.get("contact", "").strip()
|
448 |
if contact and contact.lower() != "[email protected]":
|
449 |
additional_text += f"<br>**Contact:** [email protected]"
|
450 |
st.markdown(additional_text, unsafe_allow_html=True)
|
451 |
st.divider()
|
452 |
|
453 |
+
# Bottom pagination widget for lexical results
|
454 |
col_pag_bot = st.columns([11, 1])[1]
|
455 |
+
new_page_bot = col_pag_bot.selectbox("Select Page", list(range(1, total_pages + 1)),
|
456 |
+
index=st.session_state.page - 1, key="page_bot")
|
457 |
st.session_state.page = new_page_bot
|
458 |
|
459 |
# Semantic Search Results Branch
|
|
|
472 |
start_index = (st.session_state.page - 1) * page_size
|
473 |
end_index = start_index + page_size
|
474 |
top_results = filtered_semantic_no_dupe[start_index:end_index]
|
475 |
+
|
476 |
+
# --- Retrieve and Format RAG Answer ---
|
477 |
rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
|
478 |
bullet_lines = []
|
479 |
for line in rag_answer.splitlines():
|
480 |
if line.strip():
|
481 |
+
# Clean and format the RAG answer lines
|
482 |
line = re.sub(r'^[-*]\s+', '', line.strip())
|
|
|
483 |
line = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', line)
|
|
|
|
|
484 |
bullet_lines.append(f"<li>{line}</li>")
|
485 |
formatted_rag_answer = (
|
486 |
"<div style='background-color: #f0f0f0; padding: 10px;'>"
|
|
|
491 |
st.markdown(formatted_rag_answer, unsafe_allow_html=True)
|
492 |
|
493 |
st.divider()
|
494 |
+
# Pagination controls for semantic results
|
495 |
col_title, col_pag = st.columns([13, 1])
|
496 |
with col_title:
|
|
|
497 |
page_num = f"<b style='color: green;'>{current_page}</b>" if current_page != 1 else f"<b>{current_page}</b>"
|
498 |
total_pages_str = f"<b>{total_pages}</b>"
|
499 |
st.markdown(
|
|
|
501 |
unsafe_allow_html=True
|
502 |
)
|
503 |
with col_pag:
|
504 |
+
new_page_top = st.selectbox("Select Page", list(range(1, total_pages + 1)),
|
505 |
+
index=current_page - 1, key="page_top_sem")
|
506 |
st.session_state.page = new_page_top
|
507 |
|
508 |
+
# Display each semantic result with detailed metadata and preview
|
509 |
for i, res in enumerate(top_results, start=start_index+1):
|
510 |
metadata = res.payload.get('metadata', {})
|
511 |
if "title" not in metadata:
|
|
|
577 |
st.markdown(additional_text, unsafe_allow_html=True)
|
578 |
st.divider()
|
579 |
|
580 |
+
# Bottom pagination widget for semantic results
|
581 |
col_pag_bot = st.columns([13, 1])[1]
|
582 |
+
new_page_bot = col_pag_bot.selectbox("Select Page", list(range(1, total_pages + 1)),
|
583 |
+
index=st.session_state.page - 1, key="page_bot_sem")
|
584 |
+
st.session_state.page = new_page_bot
|