# Provenance (residue from the page this file was copied from): author annikwag,
# commit "Update app.py" (d7e675b, verified).
import streamlit as st
import requests
import pandas as pd
import re
import json
import configparser
from datetime import datetime
from torch import cuda
# ------------------------------------------------------------------------------
# Import modules from the appStore package
# These modules handle data preparation, embedding, search, region handling,
# retrieval of RAG answers, and filtering utilities.
# ------------------------------------------------------------------------------
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import (
load_region_data,
clean_country_code,
get_country_name,
get_regions,
get_country_name_and_region_mapping
)
# Note: The TF-IDF extraction is currently not used in the app.
# from appStore.tfidf_extraction import extract_top_keywords
from appStore.rag_utils import (
highlight_query,
get_rag_answer,
compute_title,
format_project_id
)
from appStore.filter_utils import (
parse_budget,
filter_results,
get_crs_options
)
from appStore.crs_utils import lookup_crs_value
# ------------------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------------------
# Read model parameters from the configuration file.
MODEL_CONFIG_PATH = 'model_params.cfg'
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually read. Fail early with a clear message instead of a
# confusing NoSectionError from the first config.get() below.
if not config.read(MODEL_CONFIG_PATH):
    raise FileNotFoundError(
        f"Model configuration file '{MODEL_CONFIG_PATH}' could not be read."
    )
DEDICATED_MODEL = config.get('MODEL', 'DEDICATED_MODEL')
DEDICATED_ENDPOINT = config.get('MODEL', 'DEDICATED_ENDPOINT')
# Token passed together with DEDICATED_ENDPOINT when requesting RAG answers.
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
# Set page configuration for Streamlit
st.set_page_config(page_title="SEARCH IATI", layout='wide')
# ------------------------------------------------------------------------------
# Load and Cache Project Data
# ------------------------------------------------------------------------------
@st.cache_data
def load_project_data():
    """
    Return the processed GIZ worldwide project table.

    The result of ``process_giz_worldwide()`` is memoised with
    ``st.cache_data`` so the preparation step is not repeated on every
    Streamlit rerun.

    Returns:
        pd.DataFrame: Processed project data.
    """
    processed = process_giz_worldwide()
    return processed


project_data = load_project_data()
# ------------------------------------------------------------------------------
# Calculate Budget Range (in million euros)
# ------------------------------------------------------------------------------
_EUROS_PER_MILLION = 1e6
# Coerce budgets to numbers (unparseable entries become NaN) and drop the gaps,
# then express the bounds in millions of euros for the budget slider.
budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
min_budget_val, max_budget_val = (
    float(budget_series.min() / _EUROS_PER_MILLION),
    float(budget_series.max() / _EUROS_PER_MILLION),
)
# ------------------------------------------------------------------------------
# Prepare Region Data
# ------------------------------------------------------------------------------
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)
# ------------------------------------------------------------------------------
# Determine Device for Computation
# ------------------------------------------------------------------------------
# Use the GPU when PyTorch reports CUDA support; otherwise stay on the CPU.
device = 'cpu'
if cuda.is_available():
    device = 'cuda'
# ------------------------------------------------------------------------------
# Layout: Header and About Section
# ------------------------------------------------------------------------------
# Text shown inside the "About" expander next to the page title.
_ABOUT_TEXT = """
This prototype app uses publicly available project data from the German
International Cooperation Society (GIZ) as of 23rd February 2025.
**Please do NOT enter sensitive or personal information.**
**Note**: The answers are AI-generated and may be incorrect or misleading.
"""

# Wide title column on the left, narrow "About" column on the right.
col_title, col_about = st.columns([8, 2])
col_title.markdown(
    "<h1 style='text-align:center;'>GIZ Project Search (PROTOTYPE)</h1>",
    unsafe_allow_html=True,
)
with col_about.expander("ℹ️ About"):
    st.markdown(_ABOUT_TEXT, unsafe_allow_html=True)
# ------------------------------------------------------------------------------
# Create or Load the Embeddings Collection
# ------------------------------------------------------------------------------
collection_name = "giz_worldwide"
client = get_client()
# Debug aid: print the collections known to the client (runs on every rerun).
print(client.get_collections())

# To re-process and re-embed the documents, uncomment:
# chunks = process_giz_worldwide()
# temp_doc = create_documents(chunks, 'chunks')
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)

# Latest project end year in the collection (upper bound of the year slider).
max_end_year = get_max_end_year(client, collection_name)
# Only the sub-regions are needed for the Region filter.
_, unique_sub_regions = get_regions(region_df)

# Country display name -> code, and code -> sub-region, for the Country filter.
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
    client,
    collection_name,
    region_df,
    hybrid_search,
    clean_country_code,
    get_country_name,
)
unique_country_names = sorted(country_name_mapping)
# ------------------------------------------------------------------------------
# Session State Reset Functions
# ------------------------------------------------------------------------------
def reset_filters():
    """
    Restore every filter widget, the query and pagination to their defaults.

    The defaults are written into ``st.session_state`` under the widgets' keys.
    Streamlit raises if a widget's key is written after that widget has been
    created in the same run, so this is meant to be used as a widget callback
    (e.g. a button's ``on_click``).
    """
    default_start_year = datetime.now().year - 4
    defaults = {
        "region_filter": "All/Not allocated",
        "country_filter": "All/Not allocated",
        "end_year_range": (default_start_year, max_end_year),
        "crs_filter": "All/Not allocated",
        "min_budget": min_budget_val,
        "client_filter": "All/Not allocated",
        "query": "",
        "show_exact_matches": False,
        "page": 1,
    }
    for state_key, default_value in defaults.items():
        st.session_state[state_key] = default_value


def reset_page():
    """Send pagination back to the first page (used as on_change callback)."""
    st.session_state["page"] = 1
# ------------------------------------------------------------------------------
# Main Query Input
# ------------------------------------------------------------------------------
# Editing the question sends pagination back to the first page.
var = st.text_input("Enter Question", key="query", on_change=reset_page)
# ------------------------------------------------------------------------------
# Filter Controls - Row 1: Basic Filters
# ------------------------------------------------------------------------------
# Every filter in this row resets pagination when it changes.
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])

region_filter = col1.selectbox(
    "Region",
    ["All/Not allocated"] + sorted(unique_sub_regions),
    key="region_filter",
    on_change=reset_page,
)

# Offer only the countries of the selected sub-region; without a region
# selection every known country is listed.
filtered_country_names = (
    unique_country_names
    if region_filter == "All/Not allocated"
    else [
        country
        for country, iso_code in country_name_mapping.items()
        if iso_code_to_sub_region.get(iso_code) == region_filter
    ]
)

country_filter = col2.selectbox(
    "Country",
    ["All/Not allocated"] + filtered_country_names,
    key="country_filter",
    on_change=reset_page,
)

# Default window starts four years before the current year and runs up to the
# latest project end year found in the collection.
current_year = datetime.now().year
default_start_year = current_year - 4
end_year_range = col3.slider(
    "Project End Year",
    min_value=2010,
    max_value=max_end_year,
    value=(default_start_year, max_end_year),
    key="end_year_range",
    on_change=reset_page,
)

with col4:
    crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
    crs_filter = st.selectbox("CRS", crs_options, key="crs_filter", on_change=reset_page)

min_budget = col5.slider(
    "Minimum Project Budget (Million €)",
    min_value=min_budget_val,
    max_value=max_budget_val,
    value=min_budget_val,
    key="min_budget",
    on_change=reset_page,
)
# ------------------------------------------------------------------------------
# Filter Controls - Row 2: Additional Filters
# ------------------------------------------------------------------------------
# Five equal columns like row 1, so the Client filter lines up with the filters
# above; columns 2 to 5 are intentionally left empty.
col1_2, col2_2, col3_2, col4_2, col5_2 = st.columns(5)
with col1_2:
    client_options = sorted(project_data["client"].dropna().unique().tolist())
    # Like every other filter, changing the client resets pagination: a smaller
    # result set could otherwise leave the stored page beyond the last page.
    client_filter = st.selectbox(
        "Client",
        ["All/Not allocated"] + client_options,
        key="client_filter",
        on_change=reset_page,
    )
# ------------------------------------------------------------------------------
# Filter Controls - Row 3: Toggle and Reset Button
# ------------------------------------------------------------------------------
col_left, col_right = st.columns([11, 1])
with col_left:
    # Checkbox to toggle exact (lexical substring) match mode.
    show_exact_matches = st.checkbox(
        "Show only exact matches", key="show_exact_matches", on_change=reset_page
    )
with col_right:
    # reset_filters() writes the session-state keys of the filter widgets that
    # were already created above in this run. Streamlit raises
    # StreamlitAPIException for such writes unless they happen in a callback
    # (callbacks run before the next rerun), so the reset is wired through
    # on_click instead of `if st.button(...): reset_filters()`.
    # The narrow right-hand column already places the button on the right.
    st.button("**Reset Filters**", key="reset_button_row3", on_click=reset_filters)
# ------------------------------------------------------------------------------
# Helper Function: Validate Project ID
# ------------------------------------------------------------------------------
def valid_project_id(pid_str):
    """
    Tell whether ``pid_str`` looks like a usable project ID.

    Empty/falsy values and the placeholder strings "nan" and "none" (in any
    letter case) are rejected; any other string is accepted.

    Args:
        pid_str (str): Candidate project ID.

    Returns:
        bool: False for empty or placeholder values, True otherwise.
    """
    return bool(pid_str) and pid_str.lower() not in ("nan", "none")
# ------------------------------------------------------------------------------
# Main Search and Display Logic
# ------------------------------------------------------------------------------
from html import escape as _escape_html

RESULTS_PAGE_SIZE = 15  # results shown per page in both result modes
MIN_CONTENT_LENGTH = 5  # hits whose page_content is shorter are dropped
SEMANTIC_SCORE_THRESHOLD = 0.25  # minimum score a semantic hit must reach
PREVIEW_WORD_COUNT = 90  # description words shown before "Show more"


def format_currency(value):
    """
    Format a numerical value as a currency string in euros.

    Args:
        value: The value to format (number or numeric string).

    Returns:
        str: e.g. "€1,500,000" (fractional part truncated), or ``value``
        unchanged when it cannot be converted to an integer amount.
    """
    try:
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")) cannot be represented as an int.
        return value


def _meta_text(metadata, key):
    """Return ``metadata[key]`` as a stripped string ("" when missing or None)."""
    value = metadata.get(key)
    return "" if value is None else str(value).strip()


def _apply_user_filters(hits):
    """Apply the user-selected filters (including the client filter) to hits."""
    filtered = filter_results(
        hits,
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name,
    )
    if client_filter != "All/Not allocated":
        filtered = [
            r for r in filtered
            if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter
        ]
    return filtered


def _total_pages(total_results):
    """Number of result pages for ``total_results`` hits (at least 1)."""
    return max(1, (total_results - 1) // RESULTS_PAGE_SIZE + 1)


def _current_page(total_pages):
    """
    Return the session's page number clamped to ``1..total_pages``.

    The stored page can exceed the page count when the result set shrinks; an
    out-of-range page would not be a valid option of the page selectboxes.
    """
    page = min(max(st.session_state.get("page", 1), 1), total_pages)
    st.session_state.page = page
    return page


def _sync_page_from(widget_key):
    """Callback: store the page chosen in pagination widget ``widget_key``."""
    st.session_state.page = st.session_state[widget_key]


def _page_selector(container, widget_key, total_pages):
    """
    Render a "Select Page" selectbox in ``container`` bound to the shared page.

    The widget is seeded from ``st.session_state.page`` before it is created
    (Streamlit forbids writing a widget's key after creating it in the same
    run). Its callback writes the choice back before the next rerun, so the
    top and bottom selectors stay in sync and the page is already current when
    the results are sliced.
    """
    st.session_state[widget_key] = st.session_state.page
    container.selectbox(
        "Select Page",
        list(range(1, total_pages + 1)),
        key=widget_key,
        on_change=_sync_page_from,
        args=(widget_key,),
    )


def _pagination_header(mode_label, total_results, current_page, total_pages, widget_key):
    """Show the "Showing N ... results (Page x of y)" line with the top selector."""
    if current_page != 1:
        page_num = f"<b style='color: green;'>{current_page}</b>"
    else:
        page_num = f"<b>{current_page}</b>"
    total_pages_str = f"<b>{total_pages}</b>"
    col_title, col_pag = st.columns([13, 1])
    with col_title:
        st.markdown(
            f"Showing **{total_results}** {mode_label} Search results (Page {page_num} of {total_pages_str})",
            unsafe_allow_html=True,
        )
    _page_selector(col_pag, widget_key, total_pages)


def _project_details_html(metadata, query, exact_match_mode):
    """
    Build the markdown/HTML details panel (objective, client, years, budget,
    related projects, country, sector, contact) for one project.

    In exact-match mode the objective defaults to "None" and is highlighted;
    otherwise it defaults to "" and is shown as-is.
    """
    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
    formatted_project_budget = format_currency(metadata.get('total_project', "Unknown"))
    formatted_total_volume = format_currency(metadata.get('total_volume', "Unknown"))
    country_raw = metadata.get('country', "Unknown")

    # Strip a trailing ".0" from the CRS key and its looked-up value.
    crs_key_clean = re.sub(r'\.0$', '', _meta_text(metadata, "crs_key"))
    new_crs_value_clean = re.sub(r'\.0$', '', str(lookup_crs_value(crs_key_clean)))
    crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"

    # Predecessor / successor projects, shown only for valid IDs.
    parts = []
    for label, key in (("Predecessor Project", "predecessor_id"),
                       ("Successor Project", "successor_id")):
        raw_id = _meta_text(metadata, key)
        if not valid_project_id(raw_id):
            continue
        try:
            formatted_id = format_project_id(int(float(raw_id)))
        except Exception:
            # Best effort: show the raw ID when it cannot be parsed/formatted.
            formatted_id = raw_id
        parts.append(f"**{label}:** {formatted_id}")

    if exact_match_mode:
        objective = highlight_query(metadata.get("objective", "None"), query)
    else:
        objective = metadata.get('objective', '')

    details = (
        f"**Objective:** {objective}<br>"
        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
        f"**Projekt duration:** {start_year_str}-{end_year_str}<br>"
        f"**Budget:** Project: <b>{formatted_project_budget}</b>, Total volume: <b>{formatted_total_volume}</b>"
    )
    if parts:
        details += "<br>" + " | ".join(parts)
    details += f"<br>**Country:** {country_raw}<br>**Sector:** {crs_combined}"

    # Hide sensitive contact info: only a generic address is ever shown.
    # NOTE(review): both literals below look like e-mail addresses mangled by an
    # obfuscation filter ("[email protected]") — restore the intended values.
    contact = _meta_text(metadata, "contact")
    if contact and contact.lower() != "[email protected]":
        details += "<br>**Contact:** [email protected]"
    return details


def _render_result(rank, res, query, exact_match_mode):
    """Render one search hit: numbered title, description preview, details."""
    metadata = res.payload.get('metadata', {})
    if "title" not in metadata:
        metadata["title"] = compute_title(metadata)
    title = metadata["title"]
    if exact_match_mode:
        title = highlight_query(title, query)
    # Remove <a ...> / </a> tags from the title before rendering it.
    title_clean = re.sub(r'<a.*?>|</a>', '', title)
    st.markdown(f"#### {rank}. **{title_clean}**", unsafe_allow_html=True)

    # English description preferred, German as fallback.
    description = _meta_text(metadata, "description.en") or _meta_text(metadata, "description.de")
    if not description:
        description = "No project description available"
    words = description.split()
    preview_text = " ".join(words[:PREVIEW_WORD_COUNT])
    remainder_text = " ".join(words[PREVIEW_WORD_COUNT:])

    col_left, col_right = st.columns(2)
    with col_left:
        st.markdown(highlight_query(preview_text, query), unsafe_allow_html=True)
        if remainder_text:
            with st.expander("Show more"):
                st.markdown(highlight_query(remainder_text, query), unsafe_allow_html=True)
    with col_right:
        st.markdown(_project_details_html(metadata, query, exact_match_mode), unsafe_allow_html=True)
    st.divider()


def _format_rag_answer(rag_answer):
    """Turn the RAG answer's non-empty lines into an HTML bullet list."""
    bullet_lines = []
    for line in (rag_answer or "").splitlines():
        line = line.strip()
        if not line:
            continue
        # Drop a leading "-"/"*" bullet marker and convert **bold** to <b>.
        line = re.sub(r'^[-*]\s+', '', line)
        line = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', line)
        bullet_lines.append(f"<li>{line}</li>")
    return (
        "<div style='background-color: #f0f0f0; padding: 10px;'>"
        "<ul style='text-align: left; list-style-position: inside;'>"
        + "".join(bullet_lines) +
        "</ul></div>"
    )


if not var.strip():
    # Inform the user if no query is entered.
    st.info("Please enter a question to see results.")
else:
    # --- 1. Execute Hybrid Search ---
    results = hybrid_search(client, var, collection_name, limit=500)
    semantic_all, lexical_all = results[0], results[1]
    # Drop hits whose text is too short to be meaningful.
    semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= MIN_CONTENT_LENGTH]
    lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= MIN_CONTENT_LENGTH]
    # Only semantic hits are thresholded on their score.
    semantic_thresholded = [r for r in semantic_all if r.score >= SEMANTIC_SCORE_THRESHOLD]

    # --- 2. Apply User-Selected Filters ---
    filtered_semantic = _apply_user_filters(semantic_thresholded)
    filtered_lexical = _apply_user_filters(lexical_all)
    filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)

    # --- Reprint the user query (escaped, because it is rendered as HTML) ---
    st.markdown(
        f"<div style='text-align: left; font-size:2.1em; font-style: italic; font-weight: bold;'>Query: {_escape_html(var)}</div>",
        unsafe_allow_html=True,
    )

    # --- 3. Display Search Results Based on Matching Mode ---
    if show_exact_matches:
        # Lexical (exact match) mode: keep hits whose text contains the query
        # as a case-insensitive substring.
        query_substring = var.strip().lower()
        filtered_lexical_no_dupe = remove_duplicates([
            r for r in filtered_lexical
            if query_substring in r.payload["page_content"].lower()
        ])
        if not filtered_lexical_no_dupe:
            st.write('No exact matches, consider unchecking "Show only exact matches"')
        else:
            total_results = len(filtered_lexical_no_dupe)
            total_pages = _total_pages(total_results)
            current_page = _current_page(total_pages)
            _pagination_header("Lexical", total_results, current_page, total_pages, "page_top")
            start_index = (current_page - 1) * RESULTS_PAGE_SIZE
            paged_results = filtered_lexical_no_dupe[start_index:start_index + RESULTS_PAGE_SIZE]
            for rank, res in enumerate(paged_results, start=start_index + 1):
                _render_result(rank, res, var, exact_match_mode=True)
            # Bottom pagination widget for lexical results
            _page_selector(st.columns([11, 1])[1], "page_bot", total_pages)
    elif not filtered_semantic_no_dupe:
        st.write("No relevant results found.")
    else:
        # Semantic mode: RAG answer over the current page, then the results.
        total_results = len(filtered_semantic_no_dupe)
        total_pages = _total_pages(total_results)
        current_page = _current_page(total_pages)
        start_index = (current_page - 1) * RESULTS_PAGE_SIZE
        top_results = filtered_semantic_no_dupe[start_index:start_index + RESULTS_PAGE_SIZE]

        # --- Retrieve and Format RAG Answer ---
        rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
        st.markdown(_format_rag_answer(rag_answer), unsafe_allow_html=True)
        st.divider()

        _pagination_header("Semantic", total_results, current_page, total_pages, "page_top_sem")
        for rank, res in enumerate(top_results, start=start_index + 1):
            _render_result(rank, res, var, exact_match_mode=False)
        # Bottom pagination widget for semantic results
        _page_selector(st.columns([13, 1])[1], "page_bot_sem", total_pages)