import streamlit as st import requests import pandas as pd from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year from appStore.prep_utils import create_documents, get_client from appStore.embed import hybrid_embed_chunks from appStore.search import hybrid_search from appStore.region_utils import load_region_data, get_country_name, get_regions from appStore.tfidf_extraction import extract_top_keywords from torch import cuda import json from datetime import datetime #model_config = getconfig("model_params.cfg") ########### # ToDo move to functions # Configuration for the dedicated model DEDICATED_MODEL = "meta-llama/Llama-3.1-8B-Instruct" DEDICATED_ENDPOINT = "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud" # Write access token from the settings WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"] def get_rag_answer(query, top_results): """ Constructs a prompt from the query and the page contexts of the top results, truncates the context to avoid exceeding the token limit, then sends it to the dedicated endpoint and returns only the generated answer. """ # Combine the context from the top results (adjust the separator as needed) context = "\n\n".join([res.payload["page_content"] for res in top_results]) # Truncate the context to a maximum number of characters (e.g., 12000 characters) max_context_chars = 15000 if len(context) > max_context_chars: context = context[:max_context_chars] # Build the prompt, instructing the model to only output the final answer. prompt = ( "Using the following context, answer the question concisely. " "Only output the final answer below, without repeating the context or question.\n\n" f"Context:\n{context}\n\n" f"Question: {query}\n\n" "Answer:" ) headers = {"Authorization": f"Bearer {WRITE_ACCESS_TOKEN}"} payload = { "inputs": prompt, "parameters": { "max_new_tokens": 150 # Adjust max tokens as needed } } response = requests.post(DEDICATED_ENDPOINT, headers=headers, json=payload) if response.status_code == 200: result = response.json() answer = result[0]["generated_text"] # If the model returns the full prompt, split and extract only the portion after "Answer:" if "Answer:" in answer: answer = answer.split("Answer:")[-1].strip() return answer else: return f"Error in generating answer: {response.text}" ####### # get the device to be used eithe gpu or cpu device = 'cuda' if cuda.is_available() else 'cpu' st.set_page_config(page_title="SEARCH IATI",layout='wide') st.title("GIZ Project Database (PROTOTYPE)") var = st.text_input("Enter Search Query") # Load the region lookup CSV region_lookup_path = "docStore/regions_lookup.csv" region_df = load_region_data(region_lookup_path) #################### Create the embeddings collection and save ###################### # the steps below need to be performed only once and then commented out any unnecssary compute over-run ##### First we process and create the chunks for relvant data source #chunks = process_giz_worldwide() ##### Convert to langchain documents #temp_doc = create_documents(chunks,'chunks') ##### Embed and store docs, check if collection exist then you need to update the collection collection_name = "giz_worldwide" #hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True) ################### Hybrid Search ##################################################### client = get_client() print(client.get_collections()) # Get the maximum end_year across the entire collection max_end_year = get_max_end_year(client, collection_name) # Get all unique sub-regions _, unique_sub_regions = get_regions(region_df) # Fetch unique country codes and map to country names @st.cache_data def get_country_name_and_region_mapping(_client, collection_name, region_df): results = hybrid_search(_client, "", collection_name) country_set = set() for res in results[0] + results[1]: countries = res.payload.get('metadata', {}).get('countries', "[]") try: country_list = json.loads(countries.replace("'", '"')) # Only add codes of length 2 two_digit_codes = [code.upper() for code in country_list if len(code) == 2] country_set.update(two_digit_codes) except json.JSONDecodeError: pass # Create a mapping of {CountryName -> ISO2Code} and {ISO2Code -> SubRegion} country_name_to_code = {} iso_code_to_sub_region = {} for code in country_set: name = get_country_name(code, region_df) sub_region_row = region_df[region_df['alpha-2'] == code] sub_region = sub_region_row['sub-region'].values[0] if not sub_region_row.empty else "Not allocated" country_name_to_code[name] = code iso_code_to_sub_region[code] = sub_region return country_name_to_code, iso_code_to_sub_region # Get country name and region mappings client = get_client() country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(client, collection_name, region_df) unique_country_names = sorted(country_name_mapping.keys()) # List of country names # Layout filters in columns col1, col2, col3, col4 = st.columns([1, 1, 1, 4]) # Region filter with col1: region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions)) # Display region names # Dynamically filter countries based on selected region if region_filter == "All/Not allocated": filtered_country_names = unique_country_names # Show all countries if no region is selected else: filtered_country_names = [ name for name, code in country_name_mapping.items() if iso_code_to_sub_region.get(code) == region_filter ] # Country filter with col2: country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names) # Display filtered country names # Year range slider # ToDo add end_year filter again with col3: current_year = datetime.now().year default_start_year = current_year - 5 # 3) The max_value is now the actual max end_year from collection end_year_range = st.slider( "Project End Year", min_value=2010, max_value=max_end_year, value=(default_start_year, max_end_year), ) # Checkbox to control whether to show only exact matches show_exact_matches = st.checkbox("Show only exact matches", value=False) def filter_results(results, country_filter, region_filter, end_year_range): ## ToDo add end_year filter again filtered = [] for r in results: metadata = r.payload.get('metadata', {}) countries = metadata.get('countries', "[]") year_str = metadata.get('end_year') if year_str: extracted = extract_year(year_str) try: end_year_val = int(extracted) if extracted != "Unknown" else 0 except ValueError: end_year_val = 0 else: end_year_val = 0 # Convert countries to a list try: c_list = json.loads(countries.replace("'", '"')) c_list = [code.upper() for code in c_list if len(code) == 2] except json.JSONDecodeError: c_list = [] # Translate selected country name to iso2 selected_iso_code = country_name_mapping.get(country_filter, None) # Check if any country in the metadata matches the selected region if region_filter != "All/Not allocated": countries_in_region = [code for code in c_list if iso_code_to_sub_region.get(code) == region_filter] else: countries_in_region = c_list # Filtering if ( (country_filter == "All/Not allocated" or selected_iso_code in c_list) and (region_filter == "All/Not allocated" or countries_in_region) and (end_year_range[0] <= end_year_val <= end_year_range[1]) # ToDo add end_year filter again ): filtered.append(r) return filtered # Run the search # 1) Adjust limit so we get more than 15 results results = hybrid_search(client, var, collection_name, limit=500) # e.g., 100 or 200 # results is a tuple: (semantic_results, lexical_results) semantic_all = results[0] lexical_all = results[1] # 2) Filter out content < 20 chars (as intermediate fix to problem that e.g. super short paragraphs with few chars get high similarity score) semantic_all = [ r for r in semantic_all if len(r.payload["page_content"]) >= 5 ] lexical_all = [ r for r in lexical_all if len(r.payload["page_content"]) >= 5 ] # 2) Apply a threshold to SEMANTIC results (score >= 0.4) semantic_thresholded = [r for r in semantic_all if r.score >= 0.0] # 2) Filter the entire sets filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range) ## ToDo add end_year filter again filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range)## ToDo add end_year filter again filtered_semantic_no_dupe = remove_duplicates(filtered_semantic) # ToDo remove duplicates again? filtered_lexical_no_dupe = remove_duplicates(filtered_lexical) # Define a helper function to format currency values def format_currency(value): try: # Convert to float then int for formatting (assumes whole numbers) return f"€{int(float(value)):,}" except (ValueError, TypeError): return value # 3) Retrieve top 15 *after* filtering # Check user preference if show_exact_matches: # 1) Display heading st.write(f"Showing **Top 15 Lexical Search results** for query: {var}") # 2) Do a simple substring check (case-insensitive) # We'll create a new list lexical_substring_filtered query_substring = var.strip().lower() lexical_substring_filtered = [] for r in lexical_all: # page_content in lowercase page_text_lower = r.payload["page_content"].lower() # Keep this result only if the query substring is found if query_substring in page_text_lower: lexical_substring_filtered.append(r) # 3) Now apply your region/country/year filter on that new list filtered_lexical = filter_results( lexical_substring_filtered, country_filter, region_filter, end_year_range ) ## ToDo add end_year filter again # 4) Remove duplicates filtered_lexical_no_dupe = remove_duplicates(filtered_lexical) # 5) If empty after substring + filters + dedupe, show a custom message if not filtered_lexical_no_dupe: st.write('No exact matches, consider unchecking "Show only exact matches"') else: top_results = filtered_lexical_no_dupe[:5] rag_answer = get_rag_answer(var, top_results) st.markdown("### Generated Answer") st.write(rag_answer) st.divider() for res in top_results: # Metadata metadata = res.payload.get('metadata', {}) countries = metadata.get('countries', "[]") client_name = metadata.get('client', 'Unknown Client') start_year = metadata.get('start_year', None) end_year = metadata.get('end_year', None) total_volume = metadata.get('total_volume', "Unknown") total_project = metadata.get('total_project', "Unknown") id = metadata.get('id', "Unknown") project_name = res.payload['metadata'].get('project_name', 'Project Link') proj_id = metadata.get('id', 'Unknown') st.markdown(f"#### {project_name} [{proj_id}]") # Snippet logic (80 words) # Build snippet from objectives and descriptions. objectives = metadata.get("objectives", "") desc_de = metadata.get("description.de", "") desc_en = metadata.get("description.en", "") description = desc_de if desc_de else desc_en full_snippet = f"Objective: {objectives} Description: {description}" words = full_snippet.split() preview_word_count = 200 preview_text = " ".join(words[:preview_word_count]) remainder_text = " ".join(words[preview_word_count:]) st.write(preview_text + ("..." if remainder_text else "")) # Keywords full_text = res.payload['page_content'] top_keywords = extract_top_keywords(full_text, top_n=5) if top_keywords: st.markdown(f"_{' · '.join(top_keywords)}_") try: c_list = json.loads(countries.replace("'", '"')) except json.JSONDecodeError: c_list = [] # Only keep country names if the region lookup returns a different value. matched_countries = [] for code in c_list: if len(code) == 2: resolved_name = get_country_name(code.upper(), region_df) if resolved_name.upper() != code.upper(): matched_countries.append(resolved_name) # Format the year range start_year_str = extract_year(start_year) if start_year else "Unknown" end_year_str = extract_year(end_year) if end_year else "Unknown" formatted_project_budget = format_currency(total_project) formatted_total_volume = format_currency(total_volume) # Build the final string including a new row for countries. if matched_countries: additional_text = ( f"**{', '.join(matched_countries)}**, commissioned by **{client_name}**\n" f"Projekt duration **{start_year_str}-{end_year_str}**\n" f"Budget: Project: **{formatted_project_budget}**, Total volume: **{formatted_total_volume}**\n" f"Country: **{', '.join(matched_countries)}**" ) else: additional_text = ( f"Commissioned by **{client_name}**\n" f"Projekt duration **{start_year_str}-{end_year_str}**\n" f"Budget: Project: **{formatted_project_budget}**, Total volume: **{formatted_total_volume}**\n" f"Country: **{', '.join(c_list) if c_list else 'Unknown'}**" ) st.markdown(additional_text) st.divider() else: st.write(f"Showing **Top 15 Semantic Search results** for query: {var}") if not filtered_semantic_no_dupe: st.write("No relevant results found.") else: # Get the top 15 results for the RAG context top_results = filtered_semantic_no_dupe[:5] # Call the RAG function to generate an answer rag_answer = get_rag_answer(var, top_results) # Display the generated answer at the top of the page st.markdown("### Generated Answer") st.write(rag_answer) st.divider() # Now list each individual search result below for res in top_results: # Metadata metadata = res.payload.get('metadata', {}) countries = metadata.get('countries', "[]") client_name = metadata.get('client', 'Unknown Client') start_year = metadata.get('start_year', None) end_year = metadata.get('end_year', None) total_volume = metadata.get('total_volume', "Unknown") total_project = metadata.get('total_project', "Unknown") id = metadata.get('id', "Unknown") project_name = res.payload['metadata'].get('project_name', 'Project Link') proj_id = metadata.get('id', 'Unknown') st.markdown(f"#### {project_name} [{proj_id}]") # Snippet logic (80 words) # Build snippet from objectives and descriptions. objectives = metadata.get("objectives", "") desc_de = metadata.get("description.de", "") desc_en = metadata.get("description.en", "") description = desc_de if desc_de else desc_en full_snippet = f"Objective: {objectives} Description: {description}" words = full_snippet.split() preview_word_count = 200 preview_text = " ".join(words[:preview_word_count]) remainder_text = " ".join(words[preview_word_count:]) st.write(preview_text + ("..." if remainder_text else "")) # Keywords full_text = res.payload['page_content'] top_keywords = extract_top_keywords(full_text, top_n=5) if top_keywords: st.markdown(f"_{' · '.join(top_keywords)}_") try: c_list = json.loads(countries.replace("'", '"')) except json.JSONDecodeError: c_list = [] matched_countries = [] for code in c_list: if len(code) == 2: resolved_name = get_country_name(code.upper(), region_df) if resolved_name.upper() != code.upper(): matched_countries.append(resolved_name) # Format the year range start_year_str = extract_year(start_year) if start_year else "Unknown" end_year_str = extract_year(end_year) if end_year else "Unknown" formatted_project_budget = format_currency(total_project) formatted_total_volume = format_currency(total_volume) # Build the final string if matched_countries: additional_text = ( f"**{', '.join(matched_countries)}**, commissioned by **{client_name}**\n" f"Projekt duration **{start_year_str}-{end_year_str}**\n" f"Budget: Project: **{formatted_project_budget}**, Total volume: **{formatted_total_volume}**\n" f"Country: **{', '.join(matched_countries)}**" ) else: additional_text = ( f"Commissioned by **{client_name}**\n" f"Projekt duration **{start_year_str}-{end_year_str}**\n" f"Budget: Project: **{formatted_project_budget}**, Total volume: **{formatted_total_volume}**\n" f"Country: **{', '.join(c_list) if c_list else 'Unknown'}**" ) st.markdown(additional_text) st.divider() # for i in results: # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main'])) # st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}") # st.write(i.page_content) # st.divider()