annikwag's picture
Update app.py
d7e675b verified
raw
history blame
26.8 kB
import streamlit as st
import requests
import pandas as pd
import re
import json
import configparser
from datetime import datetime
from torch import cuda
# ------------------------------------------------------------------------------
# Import modules from the appStore package
# These modules handle data preparation, embedding, search, region handling,
# retrieval of RAG answers, and filtering utilities.
# ------------------------------------------------------------------------------
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import (
load_region_data,
clean_country_code,
get_country_name,
get_regions,
get_country_name_and_region_mapping
)
# Note: The TF-IDF extraction is currently not used in the app.
# from appStore.tfidf_extraction import extract_top_keywords
from appStore.rag_utils import (
highlight_query,
get_rag_answer,
compute_title,
format_project_id
)
from appStore.filter_utils import (
parse_budget,
filter_results,
get_crs_options
)
from appStore.crs_utils import lookup_crs_value
# ------------------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------------------
# Read model parameters from configuration file
# Parse the model parameters from the configuration file
config = configparser.ConfigParser()
config.read('model_params.cfg')
# Both model settings live in the [MODEL] section; read them in one pass.
DEDICATED_MODEL, DEDICATED_ENDPOINT = (
    config.get('MODEL', option)
    for option in ('DEDICATED_MODEL', 'DEDICATED_ENDPOINT')
)
# Access token is taken from the Streamlit secrets store
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
# Set page configuration for Streamlit
st.set_page_config(page_title="SEARCH IATI", layout='wide')
# ------------------------------------------------------------------------------
# Load and Cache Project Data
# ------------------------------------------------------------------------------
@st.cache_data
def load_project_data():
    """
    Fetch the processed GIZ worldwide project data.

    The function is wrapped in ``st.cache_data``; it simply delegates to
    ``process_giz_worldwide()`` and hands back whatever that returns.

    Returns:
        pd.DataFrame: Processed project data as a pandas DataFrame.
    """
    giz_projects = process_giz_worldwide()
    return giz_projects
project_data = load_project_data()
# ------------------------------------------------------------------------------
# Calculate Budget Range (in million euros)
# ------------------------------------------------------------------------------
# Non-numeric budgets are coerced to NaN and dropped before taking the bounds.
numeric_budgets = pd.to_numeric(project_data['total_project'], errors='coerce')
budget_series = numeric_budgets.dropna()
lowest_budget = budget_series.min()
highest_budget = budget_series.max()
min_budget_val = float(lowest_budget / 1e6)
max_budget_val = float(highest_budget / 1e6)
# ------------------------------------------------------------------------------
# Prepare Region Data
# ------------------------------------------------------------------------------
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)
# ------------------------------------------------------------------------------
# Determine Device for Computation
# ------------------------------------------------------------------------------
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# ------------------------------------------------------------------------------
# Layout: Header and About Section
# ------------------------------------------------------------------------------
# Wide title column on the left, narrow "About" column on the right.
col_title, col_about = st.columns([8, 2])
col_title.markdown(
    "<h1 style='text-align:center;'>GIZ Project Search (PROTOTYPE)</h1>",
    unsafe_allow_html=True,
)
# The disclaimer lives in a collapsible panel inside the right-hand column.
about_panel = col_about.expander("ℹ️ About")
about_panel.markdown(
"""
This prototype app uses publicly available project data from the German
International Cooperation Society (GIZ) as of 23rd February 2025.
**Please do NOT enter sensitive or personal information.**
**Note**: The answers are AI-generated and may be incorrect or misleading.
""", unsafe_allow_html=True
)
# ------------------------------------------------------------------------------
# Create or Load the Embeddings Collection
# ------------------------------------------------------------------------------
collection_name = "giz_worldwide"
client = get_client()
# Display existing collections for debugging purposes
print(client.get_collections())
# Uncomment the block below if you need to reprocess and embed documents.
# chunks = process_giz_worldwide()
# temp_doc = create_documents(chunks, 'chunks')
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
# Latest project end year found in the collection (upper bound of the year slider)
max_end_year = get_max_end_year(client, collection_name)
# Only the second element returned by get_regions is used here.
_, unique_sub_regions = get_regions(region_df)
# Build mapping between country names and region codes
mapping_args = (
    client,
    collection_name,
    region_df,
    hybrid_search,
    clean_country_code,
    get_country_name,
)
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(*mapping_args)
# Alphabetical list of the country names offered in the Country filter
unique_country_names = sorted(country_name_mapping)
# ------------------------------------------------------------------------------
# Session State Reset Functions
# ------------------------------------------------------------------------------
def reset_filters():
    """
    Restore every filter, the query text and the pagination page in
    ``st.session_state`` to its default value.

    The end-year range defaults to (current year - 4, ``max_end_year``) and the
    minimum budget to ``min_budget_val``; both are read from module scope.
    """
    unset = "All/Not allocated"
    this_year = datetime.now().year
    defaults = {
        "region_filter": unset,
        "country_filter": unset,
        "end_year_range": (this_year - 4, max_end_year),
        "crs_filter": unset,
        "min_budget": min_budget_val,
        "client_filter": unset,
        "query": "",
        "show_exact_matches": False,
        "page": 1,
    }
    # Write the defaults one key at a time, in the order listed above.
    for state_key, default_value in defaults.items():
        st.session_state[state_key] = default_value
def reset_page():
    """
    Send pagination back to page 1.

    Used as the ``on_change`` callback of the query box and the filter widgets.
    """
    st.session_state["page"] = 1
# ------------------------------------------------------------------------------
# Main Query Input
# ------------------------------------------------------------------------------
var = st.text_input("Enter Question", key="query", on_change=reset_page)
# ------------------------------------------------------------------------------
# Filter Controls - Row 1: Basic Filters
# ------------------------------------------------------------------------------
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])

# Region
region_choices = ["All/Not allocated"] + sorted(unique_sub_regions)
region_filter = col1.selectbox(
    "Region", region_choices, key="region_filter", on_change=reset_page
)

# Country: once a region is picked, only countries mapped to it are offered.
if region_filter != "All/Not allocated":
    filtered_country_names = [
        name
        for name, code in country_name_mapping.items()
        if iso_code_to_sub_region.get(code) == region_filter
    ]
else:
    filtered_country_names = unique_country_names
country_filter = col2.selectbox(
    "Country",
    ["All/Not allocated"] + filtered_country_names,
    key="country_filter",
    on_change=reset_page,
)

# Project end year: defaults to the span from four years ago up to the latest year.
current_year = datetime.now().year
default_start_year = current_year - 4
end_year_range = col3.slider(
    "Project End Year",
    min_value=2010,
    max_value=max_end_year,
    value=(default_start_year, max_end_year),
    key="end_year_range",
    on_change=reset_page,
)

# CRS
crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
crs_filter = col4.selectbox("CRS", crs_options, key="crs_filter", on_change=reset_page)

# Minimum budget (bounds are the min/max project budget in million euros)
min_budget = col5.slider(
    "Minimum Project Budget (Million €)",
    min_value=min_budget_val,
    max_value=max_budget_val,
    value=min_budget_val,
    key="min_budget",
    on_change=reset_page,
)
# ------------------------------------------------------------------------------
# Filter Controls - Row 2: Additional Filters
# ------------------------------------------------------------------------------
col1_2, col2_2, col3_2, col4_2, col5_2 = st.columns(5)
with col1_2:
    # Offer every distinct, non-null client found in the project data.
    client_options = sorted(project_data["client"].dropna().unique().tolist())
    # Like every other filter, changing the client resets pagination to page 1.
    # Without this, narrowing the results could leave st.session_state.page
    # pointing past the last page, and that value is later used as a
    # selectbox index.
    client_filter = st.selectbox(
        "Client",
        ["All/Not allocated"] + client_options,
        key="client_filter",
        on_change=reset_page,
    )
# Columns 2 to 5 are left empty for layout alignment
with col2_2:
    st.empty()
with col3_2:
    st.empty()
with col4_2:
    st.empty()
with col5_2:
    st.empty()
# ------------------------------------------------------------------------------
# Filter Controls - Row 3: Toggle and Reset Button
# ------------------------------------------------------------------------------
col_left, col_right = st.columns([11, 1])
with col_left:
    # Checkbox to toggle exact match filtering
    show_exact_matches = st.checkbox("Show only exact matches", key="show_exact_matches", on_change=reset_page)
with col_right:
    # Reset filters button (right-aligned)
    with st.container():
        # NOTE(review): the opening and closing <div> are emitted by separate
        # st.markdown calls, so they may not actually wrap the button — confirm
        # the alignment works as intended.
        st.markdown("<div style='text-align: right;'>", unsafe_allow_html=True)
        # reset_filters writes to the session-state keys of widgets created
        # earlier in this script. Streamlit only allows that before those
        # widgets are instantiated, so it must run as the button's on_click
        # callback (executed before the rerun), not inline after the click.
        st.button("**Reset Filters**", key="reset_button_row3", on_click=reset_filters)
        st.markdown("</div>", unsafe_allow_html=True)
# ------------------------------------------------------------------------------
# Helper Function: Validate Project ID
# ------------------------------------------------------------------------------
def valid_project_id(pid_str):
    """
    Decide whether a project ID string holds a usable value.

    An ID is rejected when it is empty/falsy or when it spells "nan" or
    "none" (compared case-insensitively, without trimming whitespace).

    Args:
        pid_str (str): The project ID string.
    Returns:
        bool: True if the project ID is valid, False otherwise.
    """
    return bool(pid_str) and pid_str.lower() not in ("nan", "none")
# ------------------------------------------------------------------------------
# Main Search and Display Logic
# ------------------------------------------------------------------------------
if not var.strip():
    # Without a query there is nothing to search for; just prompt the user.
    st.info("Please enter a question to see results.")
else:
    # --- 1. Execute Hybrid Search ---
    results = hybrid_search(client, var, collection_name, limit=500)
    semantic_all, lexical_all = results[0], results[1]

    def _has_content(hit):
        """Tell whether a hit's page content is at least 5 characters long."""
        return len(hit.payload["page_content"]) >= 5

    # Filter out results with very short page content
    semantic_all = list(filter(_has_content, semantic_all))
    lexical_all = list(filter(_has_content, lexical_all))
    # Keep only semantic hits scoring at least 0.25
    semantic_thresholded = [hit for hit in semantic_all if hit.score >= 0.25]

    # --- 2. Apply User-Selected Filters ---
    def _apply_user_filters(hits):
        """Run filter_results on `hits` with the current widget selections."""
        return filter_results(
            hits,
            country_filter,
            region_filter,
            end_year_range,
            crs_filter,
            min_budget,
            region_df,
            iso_code_to_sub_region,
            clean_country_code,
            get_country_name,
        )

    filtered_semantic = _apply_user_filters(semantic_thresholded)
    filtered_lexical = _apply_user_filters(lexical_all)

    # Additional filtering by client if selected
    if client_filter != "All/Not allocated":

        def _is_selected_client(hit):
            """Compare the hit's metadata client with the selected client."""
            hit_client = hit.payload.get("metadata", {}).get("client", "Unknown Client")
            return hit_client == client_filter

        filtered_semantic = [hit for hit in filtered_semantic if _is_selected_client(hit)]
        filtered_lexical = [hit for hit in filtered_lexical if _is_selected_client(hit)]

    # Remove duplicate entries from the results
    filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
    filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
def format_currency(value):
    """
    Format a numerical value as a whole-euro currency string.

    The value is converted with ``float`` and truncated with ``int``, then
    rendered with thousands separators, e.g. ``"1234567.8"`` -> ``"€1,234,567"``.

    Args:
        value: The value to format (number or numeric string).
    Returns:
        str: The formatted currency string. If the value cannot be converted
        (non-numeric text, None, NaN, infinity) it is returned unchanged.
    """
    try:
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        # ValueError: non-numeric text or NaN; TypeError: e.g. None;
        # OverflowError: int() of an infinite float.
        return value
# --- Reprint the user query for clarity ---
from html import escape as _escape_html

_PAGE_SIZE = 15            # results shown per page
_PREVIEW_WORD_COUNT = 90   # description words shown before the "Show more" expander


def _sync_page(widget_key):
    """on_change callback: copy the page picked in one selector into the shared page state."""
    st.session_state.page = st.session_state[widget_key]


def _clamped_page(total_pages):
    """Return the stored page forced into 1..total_pages, writing it back to session state."""
    page = st.session_state.get("page", 1)
    page = min(max(page, 1), total_pages)
    st.session_state.page = page
    return page


def _page_selector(container, total_pages, widget_key):
    """Render a "Select Page" selectbox in `container`, kept in sync with the shared page."""
    # Seed the widget from the shared page before it is created, so the top
    # and bottom selectors always show the same page. User changes flow back
    # through _sync_page before the next rerun.
    st.session_state[widget_key] = st.session_state.page
    container.selectbox(
        "Select Page",
        list(range(1, total_pages + 1)),
        key=widget_key,
        on_change=_sync_page,
        args=(widget_key,),
    )


def _results_header(kind, total_results, current_page, total_pages, widget_key):
    """Render the "Showing N ... results" line together with the top page selector."""
    col_title, col_pag = st.columns([13, 1])
    # The page number is shown in green on every page except the first.
    page_num = f"<b style='color: green;'>{current_page}</b>" if current_page != 1 else f"<b>{current_page}</b>"
    total_pages_str = f"<b>{total_pages}</b>"
    col_title.markdown(
        f"Showing **{total_results}** {kind} Search results (Page {page_num} of {total_pages_str})",
        unsafe_allow_html=True
    )
    _page_selector(col_pag, total_pages, widget_key)


def _related_projects_line(metadata):
    """Build the "Predecessor | Successor" line; returns "" when neither ID is valid."""
    parts = []
    for label, field in (("Predecessor Project", "predecessor_id"),
                         ("Successor Project", "successor_id")):
        raw_id = metadata.get(field, "").strip()
        if not valid_project_id(raw_id):
            continue
        try:
            shown_id = format_project_id(int(float(raw_id)))
        except Exception:
            # Best effort: fall back to the raw ID if it cannot be formatted.
            shown_id = raw_id
        parts.append(f"**{label}:** {shown_id}")
    return " | ".join(parts)


def _details_html(metadata, query, highlight_meta):
    """Build the right-hand project details (objective, client, duration, budget, ...)."""
    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
    formatted_project_budget = format_currency(metadata.get('total_project', "Unknown"))
    formatted_total_volume = format_currency(metadata.get('total_volume', "Unknown"))
    country_raw = metadata.get('country', "Unknown")

    # CRS key and looked-up value, both with a trailing ".0" removed.
    crs_key = metadata.get("crs_key", "").strip()
    crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
    new_crs_value = lookup_crs_value(crs_key_clean)
    new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
    crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"

    # Only the exact-match view highlights the query inside the objective.
    if highlight_meta:
        objective = highlight_query(metadata.get("objective", "None"), query)
    else:
        objective = metadata.get('objective', '')

    lines = [
        f"**Objective:** {objective}",
        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}",
        f"**Projekt duration:** {start_year_str}-{end_year_str}",
        f"**Budget:** Project: <b>{formatted_project_budget}</b>, Total volume: <b>{formatted_total_volume}</b>",
    ]
    extra_line = _related_projects_line(metadata)
    if extra_line:
        lines.append(extra_line)
    lines.append(f"**Country:** {country_raw}")
    lines.append(f"**Sector:** {crs_combined}")

    # The stored contact value itself is never displayed; a fixed literal is
    # shown instead.
    # NOTE(review): the literal looks like an obfuscated e-mail address —
    # confirm the intended value.
    contact = metadata.get("contact", "").strip()
    if contact and contact.lower() != "[email protected]":
        lines.append("**Contact:** [email protected]")
    return "<br>".join(lines)


def _render_result(position, res, query, highlight_meta):
    """Render one search hit: numbered title, description preview and project details.

    Args:
        position: 1-based rank of the hit across all pages.
        res: Search hit whose ``payload['metadata']`` holds the project fields.
        query: The user's query, used for highlighting.
        highlight_meta: When True, the query is also highlighted in the title
            and the objective (exact-match view).
    """
    metadata = res.payload.get('metadata', {})
    if "title" not in metadata:
        metadata["title"] = compute_title(metadata)
    title = highlight_query(metadata["title"], query) if highlight_meta else metadata["title"]
    # Strip anchor tags so the heading is not rendered as a link.
    title_clean = re.sub(r'<a.*?>|</a>', '', title)
    st.markdown(f"#### {position}. **{title_clean}**", unsafe_allow_html=True)

    # Prefer the English description, fall back to German, then to a placeholder.
    desc_en = metadata.get("description.en", "").strip()
    desc_de = metadata.get("description.de", "").strip()
    description = desc_en or desc_de or "No project description available"
    words = description.split()
    preview_text = " ".join(words[:_PREVIEW_WORD_COUNT])
    remainder_text = " ".join(words[_PREVIEW_WORD_COUNT:])

    col_left, col_right = st.columns(2)
    with col_left:
        st.markdown(highlight_query(preview_text, query), unsafe_allow_html=True)
        if remainder_text:
            with st.expander("Show more"):
                st.markdown(highlight_query(remainder_text, query), unsafe_allow_html=True)
    with col_right:
        st.markdown(_details_html(metadata, query, highlight_meta), unsafe_allow_html=True)
    # NOTE(review): divider assumed to span both columns — confirm against the
    # original layout.
    st.divider()


def _rag_answer_html(rag_answer):
    """Turn the RAG answer text into a grey HTML box with one bullet per non-empty line."""
    bullet_lines = []
    for line in rag_answer.splitlines():
        if line.strip():
            # Drop a leading "-"/"*" bullet marker and convert **bold** to <b>.
            line = re.sub(r'^[-*]\s+', '', line.strip())
            line = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', line)
            bullet_lines.append(f"<li>{line}</li>")
    return (
        "<div style='background-color: #f0f0f0; padding: 10px;'>"
        "<ul style='text-align: left; list-style-position: inside;'>"
        + "".join(bullet_lines) +
        "</ul></div>"
    )


# The query is user input placed into raw HTML, so it is escaped first.
st.markdown(
    f"<div style='text-align: left; font-size:2.1em; font-style: italic; font-weight: bold;'>Query: {_escape_html(var)}</div>",
    unsafe_allow_html=True
)
# --- 3. Display Search Results Based on Matching Mode ---
# Lexical (Exact Match) Search Results Branch
if show_exact_matches:
    # Keep only lexical hits whose content contains the query text (case-insensitive).
    query_substring = var.strip().lower()
    lexical_substring_filtered = [
        r for r in filtered_lexical
        if query_substring in r.payload["page_content"].lower()
    ]
    filtered_lexical_no_dupe = remove_duplicates(lexical_substring_filtered)
    if not filtered_lexical_no_dupe:
        st.write('No exact matches, consider unchecking "Show only exact matches"')
    else:
        # --- Pagination Setup ---
        total_results = len(filtered_lexical_no_dupe)
        total_pages = (total_results - 1) // _PAGE_SIZE + 1
        # Clamp first: the stored page may exceed the new page count.
        current_page = _clamped_page(total_pages)
        _results_header("Lexical", total_results, current_page, total_pages, "page_top")
        start_index = (current_page - 1) * _PAGE_SIZE
        paged_results = filtered_lexical_no_dupe[start_index:start_index + _PAGE_SIZE]
        # Display each result with formatted metadata and content preview
        for i, res in enumerate(paged_results, start=start_index + 1):
            _render_result(i, res, var, highlight_meta=True)
        # Bottom pagination widget for lexical results
        _page_selector(st.columns([11, 1])[1], total_pages, "page_bot")
# Semantic Search Results Branch
else:
    if not filtered_semantic_no_dupe:
        st.write("No relevant results found.")
    else:
        total_results = len(filtered_semantic_no_dupe)
        total_pages = (total_results - 1) // _PAGE_SIZE + 1
        # The page is already final here (selector changes arrive through the
        # on_change callback), so the slice and the RAG answer match the page
        # the user selected.
        current_page = _clamped_page(total_pages)
        start_index = (current_page - 1) * _PAGE_SIZE
        top_results = filtered_semantic_no_dupe[start_index:start_index + _PAGE_SIZE]
        # --- Retrieve and Format RAG Answer ---
        rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
        st.markdown(_rag_answer_html(rag_answer), unsafe_allow_html=True)
        st.divider()
        # Pagination controls for semantic results
        _results_header("Semantic", total_results, current_page, total_pages, "page_top_sem")
        # Display each semantic result with detailed metadata and preview
        for i, res in enumerate(top_results, start=start_index + 1):
            _render_result(i, res, var, highlight_meta=False)
        # Bottom pagination widget for semantic results
        _page_selector(st.columns([13, 1])[1], total_pages, "page_bot_sem")