import streamlit as st
import pandas as pd
import re
import configparser
from datetime import datetime
from torch import cuda
# Import existing modules from appStore
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import (
    load_region_data,
    clean_country_code,
    get_country_name,
    get_regions,
    get_country_name_and_region_mapping
)
# TF-IDF part (excluded from the app for now)
# from appStore.tfidf_extraction import extract_top_keywords
# Import helper modules
from appStore.rag_utils import (
    highlight_query,
    get_rag_answer,
    compute_title
)
from appStore.filter_utils import (
    parse_budget,
    filter_results,
    get_crs_options
)
from appStore.crs_utils import lookup_crs_value
###########################################
# Model Config
###########################################
# Initialize the parser and read the configuration file
config = configparser.ConfigParser()
config.read('model_params.cfg')
# Retrieve model parameters from the "MODEL" section
DEDICATED_MODEL = config.get('MODEL', 'DEDICATED_MODEL')
DEDICATED_ENDPOINT = config.get('MODEL', 'DEDICATED_ENDPOINT')
WRITE_ACCESS_TOKEN = config.get('MODEL', 'WRITE_ACCESS_TOKEN')
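# Expected model_params.cfg layout (illustrative placeholders, not real values):
# [MODEL]
# DEDICATED_MODEL = <model-id>
# DEDICATED_ENDPOINT = <inference-endpoint-url>
# WRITE_ACCESS_TOKEN = <token>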
st.set_page_config(page_title="SEARCH IATI", layout='wide')
###########################################
# Cache the project data
###########################################
@st.cache_data
def load_project_data():
    """Load and process the GIZ worldwide data, returning a pandas DataFrame."""
    return process_giz_worldwide()


project_data = load_project_data()
# Determine min and max budgets in million euros
budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
min_budget_val = float(budget_series.min() / 1e6)
max_budget_val = float(budget_series.max() / 1e6)
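# Guard: if every budget entry is non-numeric, min()/max() on the empty series yield
# NaN and the budget slider below would break. The 0-1000 M€ fallback range is an
# assumed default, not derived from the data.
if budget_series.empty or pd.isna(min_budget_val) or pd.isna(max_budget_val):
    min_budget_val, max_budget_val = 0.0, 1000.0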
###########################################
# Prepare region data
###########################################
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)
###########################################
# Get device
###########################################
device = 'cuda' if cuda.is_available() else 'cpu'
###########################################
# Streamlit App Layout
###########################################
col_title, col_about = st.columns([8, 2])
with col_title:
    st.markdown("<h1>GIZ Project Search (PROTOTYPE)</h1>", unsafe_allow_html=True)
with col_about:
    with st.expander("ℹ️ About"):
        st.markdown(
            """
            This app is a prototype for testing purposes using publicly available project data
            from the Deutsche Gesellschaft für Internationale Zusammenarbeit (GIZ) as of 23 February 2025.
            **Please do NOT enter sensitive or personal information.**
            **Note**: The answers are AI-generated and may be wrong or misleading.
            """,
            unsafe_allow_html=True
        )
# Main query input
var = st.text_input("Enter Question")
###########################################
# Create or load the embeddings collection
###########################################
collection_name = "giz_worldwide"
client = get_client()
print(client.get_collections())  # Debug: log existing collections at startup
# If needed, once only:
# chunks = process_giz_worldwide()
# temp_doc = create_documents(chunks, 'chunks')
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
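# Sketch: index only when the collection is missing, instead of re-embedding on every
# start. Assumes a qdrant-style client where get_collections() exposes .collections;
# adjust to whatever get_client() actually returns.
# existing = {c.name for c in client.get_collections().collections}
# if collection_name not in existing:
#     temp_doc = create_documents(process_giz_worldwide(), 'chunks')
#     hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)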
max_end_year = get_max_end_year(client, collection_name)
_, unique_sub_regions = get_regions(region_df)
# Build country->code and code->region mapping
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
    client,
    collection_name,
    region_df,
    hybrid_search,
    clean_country_code,
    get_country_name
)
unique_country_names = sorted(country_name_mapping.keys())
###########################################
# Filter Controls
###########################################
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
with col1:
    region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
    if region_filter == "All/Not allocated":
        filtered_country_names = unique_country_names
    else:
        filtered_country_names = [
            name for name, code in country_name_mapping.items()
            if iso_code_to_sub_region.get(code) == region_filter
        ]
with col2:
    country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
with col3:
    current_year = datetime.now().year
    default_start_year = current_year - 4
    end_year_range = st.slider(
        "Project End Year",
        min_value=2010,
        max_value=max_end_year,
        value=(default_start_year, max_end_year)
    )
with col4:
    crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
    crs_filter = st.selectbox("CRS", crs_options)
with col5:
    min_budget = st.slider(
        "Minimum Project Budget (Million €)",
        min_value=min_budget_val,
        max_value=max_budget_val,
        value=min_budget_val
    )
show_exact_matches = st.checkbox("Show only exact matches", value=False)
###########################################
# Main Search / Results
###########################################
if not var.strip():
    st.info("Please enter a question to see results.")
else:
    # 1) Perform hybrid search
    results = hybrid_search(client, var, collection_name, limit=500)
    semantic_all, lexical_all = results[0], results[1]  # (semantic hits, lexical hits)
    # Filter out very short pages
    semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
    lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
    # Optional semantic score threshold (0.0 currently keeps all results)
    semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
    # 2) Filter results based on the user's selections
    filtered_semantic = filter_results(
        semantic_thresholded,
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name
    )
    filtered_lexical = filter_results(
        lexical_all,
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name
    )
    # Remove duplicates
    filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
    filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)

    def format_currency(value):
        """
        Format a numeric or string value as currency (EUR) with thousands separators.
        """
        try:
            return f"€{int(float(value)):,}"
        except (ValueError, TypeError):
            return value
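    # Example: format_currency("2500000.0") -> "€2,500,000"; non-numeric input
    # such as "Unknown" is returned unchanged.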
    # 3) Display results
    if show_exact_matches:
        # Lexical substring match only
        st.write("Showing **Top 10 Lexical Search results**")
        query_substring = var.strip().lower()
        lexical_substring_filtered = [
            r for r in lexical_all
            if query_substring in r.payload["page_content"].lower()
        ]
        filtered_lexical = filter_results(
            lexical_substring_filtered,
            country_filter,
            region_filter,
            end_year_range,
            crs_filter,
            min_budget,
            region_df,
            iso_code_to_sub_region,
            clean_country_code,
            get_country_name
        )
        filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
        if not filtered_lexical_no_dupe:
            st.write('No exact matches, consider unchecking "Show only exact matches"')
        else:
            top_results = filtered_lexical_no_dupe[:10]
            # RAG answer
            rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
            st.markdown(f"<h3>{var}</h3>", unsafe_allow_html=True)
            st.write(rag_answer)
            st.divider()
            # Show each result
            for res in top_results:
                metadata = res.payload.get('metadata', {})
                if "title" not in metadata:
                    metadata["title"] = compute_title(metadata)
                # Title
                title_html = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
                st.markdown(f"#### {title_html}", unsafe_allow_html=True)
                # Description snippet (prefer English, fall back to German)
                objective = metadata.get("objective", "None")
                desc_en = metadata.get("description.en", "").strip()
                desc_de = metadata.get("description.de", "").strip()
                description = desc_en if desc_en else desc_de
                if not description:
                    description = "No project description available"
                words = description.split()
                preview_word_count = 90
                preview_text = " ".join(words[:preview_word_count])
                remainder_text = " ".join(words[preview_word_count:])
                col_left, col_right = st.columns(2)
                with col_left:
                    st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
                    if remainder_text:
                        with st.expander("Show more"):
                            st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
                with col_right:
                    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
                    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
                    total_project = metadata.get('total_project', "Unknown")
                    total_volume = metadata.get('total_volume', "Unknown")
                    formatted_project_budget = format_currency(total_project)
                    formatted_total_volume = format_currency(total_volume)
                    country_raw = metadata.get('country', "Unknown")
                    crs_key = metadata.get("crs_key", "").strip()
                    crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
                    new_crs_value = lookup_crs_value(crs_key_clean)
                    new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
                    crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
                    # Additional metadata (two trailing spaces force a Markdown line break)
                    additional_text = (
                        f"**Objective:** {highlight_query(objective, var)}  \n"
                        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}  \n"
                        f"**Project duration:** {start_year_str}-{end_year_str}  \n"
                        f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}  \n"
                        f"**Country:** {country_raw}  \n"
                        f"**Sector:** {crs_combined}"
                    )
                    contact = metadata.get("contact", "").strip()
                    if contact and contact.lower() != "transparenz@giz.de":
                        additional_text += "  \n**Contact:** xxx@giz.de"
                    st.markdown(additional_text, unsafe_allow_html=True)
                st.divider()
    else:
        # Semantic results
        if not filtered_semantic_no_dupe:
            st.write("No relevant results found.")
        else:
            top_results = filtered_semantic_no_dupe[:10]
            rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
            st.markdown(f"<h3>{var}</h3>", unsafe_allow_html=True)
            st.write(rag_answer)
            st.divider()
            st.write("Showing **Top 10 Semantic Search results**")
            for res in top_results:
                metadata = res.payload.get('metadata', {})
                if "title" not in metadata:
                    metadata["title"] = compute_title(metadata)
                st.markdown(f"#### {metadata['title']}")
                desc_en = metadata.get("description.en", "").strip()
                desc_de = metadata.get("description.de", "").strip()
                description = desc_en if desc_en else desc_de
                if not description:
                    description = "No project description available"
                words = description.split()
                preview_word_count = 90
                preview_text = " ".join(words[:preview_word_count])
                remainder_text = " ".join(words[preview_word_count:])
                col_left, col_right = st.columns(2)
                with col_left:
                    st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
                    if remainder_text:
                        with st.expander("Show more"):
                            st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
                with col_right:
                    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
                    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
                    total_project = metadata.get('total_project', "Unknown")
                    total_volume = metadata.get('total_volume', "Unknown")
                    formatted_project_budget = format_currency(total_project)
                    formatted_total_volume = format_currency(total_volume)
                    country_raw = metadata.get('country', "Unknown")
                    crs_key = metadata.get("crs_key", "").strip()
                    crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
                    new_crs_value = lookup_crs_value(crs_key_clean)
                    new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
                    crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
                    # Additional metadata (two trailing spaces force a Markdown line break)
                    additional_text = (
                        f"**Objective:** {metadata.get('objective', '')}  \n"
                        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}  \n"
                        f"**Project duration:** {start_year_str}-{end_year_str}  \n"
                        f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}  \n"
                        f"**Country:** {country_raw}  \n"
                        f"**Sector:** {crs_combined}"
                    )
                    contact = metadata.get("contact", "").strip()
                    if contact and contact.lower() != "transparenz@giz.de":
                        additional_text += "  \n**Contact:** xxx@giz.de"
                    st.markdown(additional_text, unsafe_allow_html=True)
                st.divider()