diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,2785 +1,2785 @@
-# app.py
-
-import streamlit as st
-
-# Set page config first, before any other st commands
-st.set_page_config(page_title="SNAP", layout="wide")
-
-# Add warning filters
-import warnings
-# More specific warning filters for torch.classes
-warnings.filterwarnings('ignore', message='.*torch.classes.*__path__._path.*')
-warnings.filterwarnings('ignore', message='.*torch.classes.*registered via torch::class_.*')
-
-import pandas as pd
-import numpy as np
-import os
-import io
-import time
-from datetime import datetime
-import base64
-import re
-import pickle
-from typing import List, Dict, Any, Tuple
-import plotly.express as px
-import torch
-
-# For parallelism
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-
-# Import necessary libraries for embeddings, clustering, and summarization
-from sentence_transformers import SentenceTransformer
-from sklearn.metrics.pairwise import cosine_similarity
-from bertopic import BERTopic
-from hdbscan import HDBSCAN
-import nltk
-from nltk.corpus import stopwords
-from nltk.tokenize import word_tokenize
-
-# For summarization and chat
-from langchain.chains import LLMChain
-from langchain_community.chat_models import ChatOpenAI
-from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
-from openai import OpenAI
-from transformers import GPT2TokenizerFast
-
-# Initialize OpenAI client and tokenizer
-client = OpenAI()
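-# The client picks up credentials from the OPENAI_API_KEY environment variable;
-# no API key is hard-coded here.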
-
-###############################################################################
-# Helper: Attempt to get this file's directory or fallback to current working dir
-###############################################################################
-def get_base_dir():
- try:
- base_dir = os.path.dirname(__file__)
- if not base_dir:
- return os.getcwd()
- return base_dir
- except NameError:
- # In case __file__ is not defined (some environments)
- return os.getcwd()
-
-BASE_DIR = get_base_dir()
-
-# Function to get or create model directory
-def get_model_dir():
- base_dir = get_base_dir()
- model_dir = os.path.join(base_dir, 'models')
- os.makedirs(model_dir, exist_ok=True)
- return model_dir
-
-# Function to load tokenizer from local storage or download
-def load_tokenizer():
- model_dir = get_model_dir()
- tokenizer_dir = os.path.join(model_dir, 'tokenizer')
- os.makedirs(tokenizer_dir, exist_ok=True)
-
- try:
- # Try to load from local directory first
- tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir)
- #st.success("Loaded tokenizer from local storage")
- except Exception as e:
- #st.warning("Downloading tokenizer (one-time operation)...")
- try:
- # Download and save to local directory
- tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") # Use standard GPT2 tokenizer
- tokenizer.save_pretrained(tokenizer_dir)
- #st.success("Downloaded and saved tokenizer")
- except Exception as download_e:
- #st.error(f"Error downloading tokenizer: {str(download_e)}")
- raise
-
- return tokenizer
-
-# Load tokenizer
-try:
- tokenizer = load_tokenizer()
-except Exception as e:
- #st.error("Failed to load tokenizer. Some functionality may be limited.")
- tokenizer = None
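-    # Downstream callers should treat a None tokenizer as "skip token counting/trimming"
-    # (see generate_raw_cluster_summary below).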
-
-MAX_CONTEXT_WINDOW = 128000 # GPT-4o context window size
-
-# Initialize chat history in session state if not exists
-if 'chat_history' not in st.session_state:
- st.session_state.chat_history = []
-
-###############################################################################
-# Helper: Get chat response from OpenAI
-###############################################################################
-def get_chat_response(messages):
- try:
- response = client.chat.completions.create(
- model="gpt-4o-mini",
- messages=messages,
- temperature=0,
- )
- return response.choices[0].message.content.strip()
- except Exception as e:
- st.error(f"Error querying OpenAI: {e}")
- return None
-
-###############################################################################
-# Helper: Generate raw summary for a cluster (without references)
-###############################################################################
-def generate_raw_cluster_summary(
- topic_val: int,
- cluster_df: pd.DataFrame,
- llm: Any,
- chat_prompt: Any
-) -> Dict[str, Any]:
- """Generate a summary for a single cluster without reference enhancement,
- automatically trimming text if it exceeds a safe token limit."""
- cluster_text = " ".join(cluster_df['text'].tolist())
- if not cluster_text.strip():
- return None
-
-    # Define a safe limit (95% of max context window to leave room for prompts)
-    safe_limit = int(MAX_CONTEXT_WINDOW * 0.95)
-
-    # If the text is too large, trim it; skip trimming when the tokenizer failed to load
-    if tokenizer is not None:
-        encoded_text = tokenizer.encode(cluster_text, add_special_tokens=False)
-        if len(encoded_text) > safe_limit:
-            #st.warning(f"Cluster {topic_val} text is too large ({len(encoded_text)} tokens). Trimming to {safe_limit} tokens.")
-            encoded_text = encoded_text[:safe_limit]
-            cluster_text = tokenizer.decode(encoded_text)
-
- user_prompt_local = f"**Text to summarize**: {cluster_text}"
- try:
- local_chain = LLMChain(llm=llm, prompt=chat_prompt)
- summary_local = local_chain.run(user_prompt=user_prompt_local).strip()
- return {'Topic': topic_val, 'Summary': summary_local}
- except Exception as e:
- st.error(f"Error generating summary for cluster {topic_val}: {str(e)}")
- return None
-
-###############################################################################
-# Helper: Enhance a summary with references
-###############################################################################
-def enhance_summary_with_references(
- summary_dict: Dict[str, Any],
- df_scope: pd.DataFrame,
- reference_id_column: str,
- url_column: str = None,
- llm: Any = None
-) -> Dict[str, Any]:
- """Add references to a summary."""
- if not summary_dict or 'Summary' not in summary_dict:
- return summary_dict
-
- try:
- cluster_df = df_scope[df_scope['Topic'] == summary_dict['Topic']]
- enhanced = add_references_to_summary(
- summary_dict['Summary'],
- cluster_df,
- reference_id_column,
- url_column,
- llm
- )
- summary_dict['Enhanced_Summary'] = enhanced
- return summary_dict
- except Exception as e:
- st.error(f"Error enhancing summary for cluster {summary_dict.get('Topic')}: {str(e)}")
- return summary_dict
-
-###############################################################################
-# Helper: Process summaries in parallel
-###############################################################################
-def process_summaries_in_parallel(
- df_scope: pd.DataFrame,
- unique_selected_topics: List[int],
- llm: Any,
- chat_prompt: Any,
- enable_references: bool = False,
- reference_id_column: str = None,
- url_column: str = None,
- max_workers: int = 16
-) -> List[Dict[str, Any]]:
- """Process multiple cluster summaries in parallel using ThreadPoolExecutor."""
- summaries = []
- total_topics = len(unique_selected_topics)
-
- # Create progress placeholders
- progress_text = st.empty()
- progress_bar = st.progress(0)
-
- try:
- # Phase 1: Generate raw summaries in parallel
- progress_text.text(f"Phase 1/3: Generating cluster summaries in parallel (0/{total_topics} completed)")
- completed_summaries = 0
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- # Submit summary generation tasks
- future_to_topic = {
- executor.submit(
- generate_raw_cluster_summary,
- topic_val,
- df_scope[df_scope['Topic'] == topic_val],
- llm,
- chat_prompt
- ): topic_val
- for topic_val in unique_selected_topics
- }
-
- # Process completed summary tasks
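-            # Note: iterating the dict visits futures in submission order, and
-            # future.result() blocks until each one finishes, so progress advances
-            # in submission order rather than completion order.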
- for future in future_to_topic:
- try:
- result = future.result()
- if result:
- summaries.append(result)
- completed_summaries += 1
- # Update progress
- progress = completed_summaries / total_topics
- progress_bar.progress(progress)
- progress_text.text(
- f"Phase 1/3: Generating cluster summaries in parallel ({completed_summaries}/{total_topics} completed)"
- )
- except Exception as e:
- topic_val = future_to_topic[future]
- st.error(f"Error in summary generation for cluster {topic_val}: {str(e)}")
- completed_summaries += 1
- continue
-
- # Phase 2: Enhance summaries with references in parallel (if enabled)
- if enable_references and reference_id_column and summaries:
- total_to_enhance = len(summaries)
- completed_enhancements = 0
- progress_text.text(f"Phase 2/3: Adding references to summaries (0/{total_to_enhance} completed)")
- progress_bar.progress(0)
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- # Submit reference enhancement tasks
- future_to_summary = {
- executor.submit(
- enhance_summary_with_references,
- summary_dict,
- df_scope,
- reference_id_column,
- url_column,
- llm
- ): summary_dict.get('Topic')
- for summary_dict in summaries
- }
-
- # Process completed enhancement tasks
- enhanced_summaries = []
- for future in future_to_summary:
- try:
- result = future.result()
- if result:
- enhanced_summaries.append(result)
- completed_enhancements += 1
- # Update progress
- progress = completed_enhancements / total_to_enhance
- progress_bar.progress(progress)
- progress_text.text(
- f"Phase 2/3: Adding references to summaries ({completed_enhancements}/{total_to_enhance} completed)"
- )
- except Exception as e:
- topic_val = future_to_summary[future]
- st.error(f"Error in reference enhancement for cluster {topic_val}: {str(e)}")
- completed_enhancements += 1
- continue
-
- summaries = enhanced_summaries
-
- # Phase 3: Generate cluster names in parallel
- if summaries:
- total_to_name = len(summaries)
- completed_names = 0
- progress_text.text(f"Phase 3/3: Generating cluster names (0/{total_to_name} completed)")
- progress_bar.progress(0)
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- # Submit cluster naming tasks
- future_to_summary = {
- executor.submit(
- generate_cluster_name,
- summary_dict.get('Enhanced_Summary', summary_dict['Summary']),
- llm
- ): summary_dict.get('Topic')
- for summary_dict in summaries
- }
-
- # Process completed naming tasks
- named_summaries = []
- for future in future_to_summary:
- try:
- cluster_name = future.result()
- topic_val = future_to_summary[future]
- # Find the corresponding summary dict
- summary_dict = next(s for s in summaries if s['Topic'] == topic_val)
- summary_dict['Cluster_Name'] = cluster_name
- named_summaries.append(summary_dict)
- completed_names += 1
- # Update progress
- progress = completed_names / total_to_name
- progress_bar.progress(progress)
- progress_text.text(
- f"Phase 3/3: Generating cluster names ({completed_names}/{total_to_name} completed)"
- )
- except Exception as e:
- topic_val = future_to_summary[future]
- st.error(f"Error in cluster naming for cluster {topic_val}: {str(e)}")
- completed_names += 1
- continue
-
- summaries = named_summaries
- finally:
- # Clean up progress indicators
- progress_text.empty()
- progress_bar.empty()
-
- return summaries
-
-###############################################################################
-# Helper: Generate cluster name
-###############################################################################
-def generate_cluster_name(summary_text: str, llm: Any) -> str:
- """Generate a concise, descriptive name for a cluster based on its summary."""
- system_prompt = """You are a cluster naming expert. Your task is to generate a very concise (3-6 words) but descriptive name for a cluster based on its summary. The name should capture the main theme or focus of the cluster.
-
-Rules:
-1. Keep it between 3-6 words
-2. Be specific but concise
-3. Capture the main theme/focus
-4. Use title case
-5. Do not include words like "Cluster", "Topic", or "Theme"
-6. Focus on the content, not metadata
-
-Example good names:
-- Agricultural Water Management Innovation
-- Gender Equality in Farming
-- Climate-Smart Village Implementation
-- Sustainable Livestock Practices"""
-
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": f"Generate a concise cluster name based on this summary:\n\n{summary_text}"}
- ]
-
- try:
- response = get_chat_response(messages)
- # Clean up response (remove quotes, newlines, etc.)
- cluster_name = response.strip().strip('"').strip("'").strip()
- return cluster_name
- except Exception as e:
- st.error(f"Error generating cluster name: {str(e)}")
- return "Unnamed Cluster"
-
-###############################################################################
-# NLTK Resource Initialization
-###############################################################################
-def init_nltk_resources():
- """Initialize NLTK resources with better error handling and less verbose output"""
- nltk.data.path.append('/home/appuser/nltk_data') # Ensure consistent data path
-
- resources = {
-        'tokenizers/punkt_tab': 'punkt_tab',  # Updated to use punkt_tab
- 'corpora/stopwords': 'stopwords'
- }
-
- for resource_path, resource_name in resources.items():
- try:
- nltk.data.find(resource_path)
- except LookupError:
- try:
- nltk.download(resource_name, quiet=True)
- except Exception as e:
- st.warning(f"Error downloading NLTK resource {resource_name}: {e}")
-
- # Test tokenizer silently
- try:
- from nltk.tokenize import PunktSentenceTokenizer
- tokenizer = PunktSentenceTokenizer()
- tokenizer.tokenize("Test sentence.")
- except Exception as e:
- st.error(f"Error initializing NLTK tokenizer: {e}")
- try:
- nltk.download('punkt_tab', quiet=True) # Updated to use punkt_tab
- except Exception as e:
- st.error(f"Failed to download punkt_tab tokenizer: {e}")
-
-# Initialize NLTK resources
-init_nltk_resources()
-
-###############################################################################
-# Function: add_references_to_summary
-###############################################################################
-def add_references_to_summary(summary, source_df, reference_column, url_column=None, llm=None):
- """
- Add references to a summary by identifying which parts of the summary come
- from which source documents. References will be appended as [ID],
- optionally linked if a URL column is provided.
-
- Args:
- summary (str): The summary text to enhance with references.
- source_df (DataFrame): DataFrame containing the source documents.
- reference_column (str): Column name to use for reference IDs.
- url_column (str, optional): Column name containing URLs for hyperlinks.
- llm (LLM, optional): Language model for source attribution.
- Returns:
- str: Enhanced summary with references as HTML if possible.
- """
- if summary.strip() == "" or source_df.empty or reference_column not in source_df.columns:
- return summary
-
- # If no LLM is provided, we can't do source attribution
- if llm is None:
- return summary
-
- # Split the summary into paragraphs first
- paragraphs = summary.split('\n\n')
- enhanced_paragraphs = []
-
- # Prepare source texts with their reference IDs
- source_texts = []
- reference_ids = []
- urls = []
- for _, row in source_df.iterrows():
- if 'text' in row and pd.notna(row['text']) and pd.notna(row[reference_column]):
- source_texts.append(str(row['text']))
- reference_ids.append(str(row[reference_column]))
- if url_column and url_column in row and pd.notna(row[url_column]):
- urls.append(str(row[url_column]))
- else:
- urls.append(None)
- if not source_texts:
- return summary
-
- # Create a mapping between URLs and reference IDs
- url_map = {}
- for ref_id, u in zip(reference_ids, urls):
- if u:
- url_map[ref_id] = u
-
-    # Build the source listing once, outside the attribution loops; this avoids
-    # recomputing it for every sentence and keeps backslashes out of f-string
-    # expressions (a syntax error before Python 3.12).
-    sources_block = '\n'.join(
-        f"ID: {ref_id}, Text: {text[:500]}..."
-        for ref_id, text in zip(reference_ids, source_texts)
-    )
-
-    # Define the system prompt for source attribution
-    system_prompt = """
-    You are an expert at identifying the source of information. You will be given:
-    1. A sentence or bullet point from a summary
-    2. A list of source texts with their IDs
-
-    Your task is to identify which source text(s) the text most likely came from.
-    Return ONLY the IDs of the source texts that contributed to the text, separated by commas.
-    If you cannot confidently attribute the text to any source, return "unknown".
-    """
-
- for paragraph in paragraphs:
- if not paragraph.strip():
- enhanced_paragraphs.append('')
- continue
-
- # Check if it's a bullet point list
- if any(line.strip().startswith('- ') or line.strip().startswith('* ') for line in paragraph.split('\n')):
- # Handle bullet points
- bullet_lines = paragraph.split('\n')
- enhanced_bullets = []
- for line in bullet_lines:
- if not line.strip():
- enhanced_bullets.append(line)
- continue
-
- if line.strip().startswith('- ') or line.strip().startswith('* '):
- # Process each bullet point
- user_prompt = f"""
- Text: {line.strip()}
-
- Source texts:
-                    {sources_block}
-
- Which source ID(s) did this text most likely come from? Return only the ID(s) separated by commas, or "unknown".
- """
-
- try:
- system_message = SystemMessagePromptTemplate.from_template(system_prompt)
- human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
- chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
- chain = LLMChain(llm=llm, prompt=chat_prompt)
- response = chain.run(user_prompt=user_prompt)
- source_ids = response.strip()
-
- if source_ids.lower() == "unknown":
- enhanced_bullets.append(line)
- else:
- # Extract just the IDs
- source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
- source_ids = re.sub(r'\s+', '', source_ids)
- ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
-
- if ids:
- ref_parts = []
- for id_ in ids:
- if id_ in url_map:
-                                        ref_parts.append(f'<a href="{url_map[id_]}">{id_}</a>')
- else:
- ref_parts.append(id_)
- ref_string = ", ".join(ref_parts)
- enhanced_bullets.append(f"{line} [{ref_string}]")
- else:
- enhanced_bullets.append(line)
- except Exception:
- enhanced_bullets.append(line)
- else:
- enhanced_bullets.append(line)
-
- enhanced_paragraphs.append('\n'.join(enhanced_bullets))
- else:
- # Handle regular paragraphs
- sentences = re.split(r'(?<=[.!?])\s+', paragraph)
- enhanced_sentences = []
-
- for sentence in sentences:
- if not sentence.strip():
- continue
-
- user_prompt = f"""
- Sentence: {sentence.strip()}
-
- Source texts:
-                {sources_block}
-
- Which source ID(s) did this sentence most likely come from? Return only the ID(s) separated by commas, or "unknown".
- """
-
- try:
- system_message = SystemMessagePromptTemplate.from_template(system_prompt)
- human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
- chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
- chain = LLMChain(llm=llm, prompt=chat_prompt)
- response = chain.run(user_prompt=user_prompt)
- source_ids = response.strip()
-
- if source_ids.lower() == "unknown":
- enhanced_sentences.append(sentence)
- else:
- # Extract just the IDs
- source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
- source_ids = re.sub(r'\s+', '', source_ids)
- ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
-
- if ids:
- ref_parts = []
- for id_ in ids:
- if id_ in url_map:
-                                    ref_parts.append(f'<a href="{url_map[id_]}">{id_}</a>')
- else:
- ref_parts.append(id_)
- ref_string = ", ".join(ref_parts)
- enhanced_sentences.append(f"{sentence} [{ref_string}]")
- else:
- enhanced_sentences.append(sentence)
- except Exception:
- enhanced_sentences.append(sentence)
-
- enhanced_paragraphs.append(' '.join(enhanced_sentences))
-
- # Join paragraphs back together with double newlines to preserve formatting
- return '\n\n'.join(enhanced_paragraphs)
-
-
-st.sidebar.image("static/SNAP_logo.png", width=350)
-
-###############################################################################
-# Device / GPU Info
-###############################################################################
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-if device == 'cuda':
- st.sidebar.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
-else:
- st.sidebar.info("Using CPU")
-
-###############################################################################
-# Load or Compute Embeddings
-###############################################################################
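-# st.cache_resource keeps a single shared SentenceTransformer instance per process,
-# so the model is loaded once and reused across reruns and sessions.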
-@st.cache_resource
-def get_embedding_model():
- model_dir = get_model_dir()
- st_model_dir = os.path.join(model_dir, 'sentence_transformer')
- os.makedirs(st_model_dir, exist_ok=True)
-
- model_name = 'all-MiniLM-L6-v2'
- try:
- # Try to load from local directory first
- model = SentenceTransformer(st_model_dir)
- #st.success("Loaded sentence transformer from local storage")
- except Exception as e:
- #st.warning("Downloading sentence transformer model (one-time operation)...")
- try:
- # Download and save to local directory
- model = SentenceTransformer(model_name)
- model.save(st_model_dir)
- #st.success("Downloaded and saved sentence transformer model")
- except Exception as download_e:
- st.error(f"Error downloading sentence transformer model: {str(download_e)}")
- raise
-
- return model.to(device)
-
-def generate_embeddings(texts, model):
- with st.spinner('Calculating embeddings...'):
- embeddings = model.encode(texts, show_progress_bar=True, device=device)
- return embeddings
-
-@st.cache_data
-def load_default_dataset(default_dataset_path):
- if os.path.exists(default_dataset_path):
- df_ = pd.read_excel(default_dataset_path)
- return df_
- else:
- st.error("Default dataset not found. Please ensure the file exists in the 'input' directory.")
- return None
-
-@st.cache_data
-def load_uploaded_dataset(uploaded_file):
- df_ = pd.read_excel(uploaded_file)
- return df_
-
-def load_or_compute_embeddings(df, using_default_dataset, uploaded_file_name=None, text_columns=None):
- """
- Loads pre-computed embeddings from a pickle file if they match current data,
- otherwise computes and caches them.
- """
- if not text_columns:
- return None, None
-
- base_name = "PRMS_2022_2023_2024_QAed" if using_default_dataset else "custom_dataset"
- if uploaded_file_name:
- base_name = os.path.splitext(uploaded_file_name)[0]
-
- cols_key = "_".join(sorted(text_columns))
- timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
-
- embeddings_dir = BASE_DIR
- if using_default_dataset:
- embeddings_file = os.path.join(embeddings_dir, f'{base_name}_{cols_key}.pkl')
- else:
- # For custom dataset, we still try to avoid regenerating each time
- embeddings_file = os.path.join(embeddings_dir, f"{base_name}_{cols_key}.pkl")
-
- df_fill = df.fillna("")
- texts = df_fill[text_columns].astype(str).agg(' '.join, axis=1).tolist()
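-    # Embeddings are computed in this row order over the full DataFrame; semantic
-    # search and clustering later slice into the same array by row position.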
-
- # If already in session_state with matching columns and length, reuse
- if ('embeddings' in st.session_state
- and 'last_text_columns' in st.session_state
- and st.session_state['last_text_columns'] == text_columns
- and len(st.session_state['embeddings']) == len(texts)):
- return st.session_state['embeddings'], st.session_state.get('embeddings_file', None)
-
- # Try to load from disk
- if os.path.exists(embeddings_file):
- with open(embeddings_file, 'rb') as f:
- embeddings = pickle.load(f)
- if len(embeddings) == len(texts):
- st.write("Loaded pre-calculated embeddings.")
- st.session_state['embeddings'] = embeddings
- st.session_state['embeddings_file'] = embeddings_file
- st.session_state['last_text_columns'] = text_columns
- return embeddings, embeddings_file
-
- # Otherwise compute
- st.write("Generating embeddings...")
- model = get_embedding_model()
- embeddings = generate_embeddings(texts, model)
- with open(embeddings_file, 'wb') as f:
- pickle.dump(embeddings, f)
-
- st.session_state['embeddings'] = embeddings
- st.session_state['embeddings_file'] = embeddings_file
- st.session_state['last_text_columns'] = text_columns
- return embeddings, embeddings_file
-
-
-###############################################################################
-# Reset Filter Function
-###############################################################################
-def reset_filters():
- st.session_state['selected_additional_filters'] = {}
-
-# View selector
-st.sidebar.radio("Select view", ["Automatic Mode", "Power User Mode"], key="view")
-
-if st.session_state.view == "Power User Mode":
- st.header("Power User Mode")
- ###############################################################################
- # Sidebar: Dataset Selection
- ###############################################################################
- st.sidebar.title("Data Selection")
- dataset_option = st.sidebar.selectbox('Select Dataset', ('PRMS 2022+2023+2024 QAed', 'Upload my dataset'))
-
- if 'df' not in st.session_state:
- st.session_state['df'] = pd.DataFrame()
- if 'filtered_df' not in st.session_state:
- st.session_state['filtered_df'] = pd.DataFrame()
-
- if dataset_option == 'PRMS 2022+2023+2024 QAed':
- default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
- df = load_default_dataset(default_dataset_path)
- if df is not None:
- st.session_state['df'] = df.copy()
- st.session_state['using_default_dataset'] = True
-
- # Initialize filtered_df with full dataset by default
- if 'filtered_df' not in st.session_state or st.session_state['filtered_df'].empty:
- st.session_state['filtered_df'] = df.copy()
-
- # Initialize filter_state if not exists
- if 'filter_state' not in st.session_state:
- st.session_state['filter_state'] = {
- 'applied': False,
- 'filters': {}
- }
-
- # Set default text columns if not already set
- if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
- default_text_cols = []
- if 'Title' in df.columns and 'Description' in df.columns:
- default_text_cols = ['Title', 'Description']
- st.session_state['text_columns'] = default_text_cols
-
- #st.write("Using default dataset:")
- #st.write("Data Preview:")
- #st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
- #st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
-
- df_cols = df.columns.tolist()
-
- # Additional filter columns
- st.subheader("Select Filters")
- if 'additional_filters_selected' not in st.session_state:
- st.session_state['additional_filters_selected'] = []
- if 'filter_values' not in st.session_state:
- st.session_state['filter_values'] = {}
-
- with st.form("filter_selection_form"):
- all_columns = df.columns.tolist()
- selected_additional_cols = st.multiselect(
- "Select columns from your dataset to use as filters:",
- all_columns,
- default=st.session_state['additional_filters_selected']
- )
- add_filters_submitted = st.form_submit_button("Add Additional Filters")
-
- if add_filters_submitted:
- if selected_additional_cols != st.session_state['additional_filters_selected']:
- st.session_state['additional_filters_selected'] = selected_additional_cols
- # Reset removed columns
- st.session_state['filter_values'] = {
- k: v for k, v in st.session_state['filter_values'].items()
- if k in selected_additional_cols
- }
-
- # Show dynamic filters form if any selected columns
- if st.session_state['additional_filters_selected']:
- st.subheader("Apply Filters")
-
- # Quick search section (outside form)
- for col_name in st.session_state['additional_filters_selected']:
- unique_vals = sorted(df[col_name].dropna().unique().tolist())
-
- # Add a search box for quick selection
- search_key = f"search_{col_name}"
- if search_key not in st.session_state:
- st.session_state[search_key] = ""
-
- col1, col2 = st.columns([3, 1])
- with col1:
- search_term = st.text_input(
- f"Search in {col_name}",
- key=search_key,
- help="Enter text to find and select all matching values"
- )
- with col2:
- if st.button(f"Select Matching", key=f"select_{col_name}"):
- # Handle comma-separated values
- if search_term:
- matching_vals = [
- val for val in unique_vals
- if any(search_term.lower() in str(part).lower()
- for part in (val.split(',') if isinstance(val, str) else [val]))
- ]
- # Update the multiselect default value
- current_selected = st.session_state['filter_values'].get(col_name, [])
- st.session_state['filter_values'][col_name] = list(set(current_selected + matching_vals))
-
- # Show feedback about matches
- if matching_vals:
- st.success(f"Found and selected {len(matching_vals)} matching values")
- else:
- st.warning("No matching values found")
-
- # Filter application form
- with st.form("apply_filters_form"):
- for col_name in st.session_state['additional_filters_selected']:
- unique_vals = sorted(df[col_name].dropna().unique().tolist())
- selected_vals = st.multiselect(
- f"Filter by {col_name}",
- options=unique_vals,
- default=st.session_state['filter_values'].get(col_name, [])
- )
- st.session_state['filter_values'][col_name] = selected_vals
-
- # Add clear filters button and apply filters button
- col1, col2 = st.columns([1, 4])
- with col1:
- clear_filters = st.form_submit_button("Clear All")
- with col2:
- apply_filters_submitted = st.form_submit_button("Apply Filters to Dataset")
-
- if clear_filters:
- st.session_state['filter_values'] = {}
- # Clear any existing summary data when filters are cleared
- if 'summary_df' in st.session_state:
- del st.session_state['summary_df']
- if 'high_level_summary' in st.session_state:
- del st.session_state['high_level_summary']
- if 'enhanced_summary' in st.session_state:
- del st.session_state['enhanced_summary']
- st.rerun()
-
- # Text columns selection moved to Advanced Settings
- with st.expander("⚙️ Advanced Settings", expanded=False):
- st.subheader("**Select Text Columns for Embedding**")
- text_columns_selected = st.multiselect(
- "Text Columns:",
- df_cols,
- default=st.session_state['text_columns'],
- help="Choose columns containing text for semantic search and clustering. "
- "If multiple are selected, their text will be concatenated."
- )
- st.session_state['text_columns'] = text_columns_selected
-
- # Apply filters to the dataset
- filtered_df = df.copy()
- if 'apply_filters_submitted' in locals() and apply_filters_submitted:
- # Clear any existing summary data when new filters are applied
- if 'summary_df' in st.session_state:
- del st.session_state['summary_df']
- if 'high_level_summary' in st.session_state:
- del st.session_state['high_level_summary']
- if 'enhanced_summary' in st.session_state:
- del st.session_state['enhanced_summary']
-
- for col_name in st.session_state['additional_filters_selected']:
- selected_vals = st.session_state['filter_values'].get(col_name, [])
- if selected_vals:
- filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
- st.success("Filters applied successfully!")
- st.session_state['filtered_df'] = filtered_df.copy()
- st.session_state['filter_state'] = {
- 'applied': True,
- 'filters': st.session_state['filter_values'].copy()
- }
- # Reset any existing clustering results
- for k in ['clustered_data', 'topic_model', 'current_clustering_data',
- 'current_clustering_option', 'hierarchy']:
- if k in st.session_state:
- del st.session_state[k]
-
- elif 'filter_state' in st.session_state and st.session_state['filter_state']['applied']:
- # Reapply stored filters
- for col_name, selected_vals in st.session_state['filter_state']['filters'].items():
- if selected_vals:
- filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
- st.session_state['filtered_df'] = filtered_df.copy()
-
- # Show current data preview and download button
- if st.session_state['filtered_df'] is not None:
- if st.session_state['filter_state']['applied']:
- st.write("Filtered Data Preview:")
- else:
- st.write("Current Data Preview:")
- st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
- st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
-
- output = io.BytesIO()
- writer = pd.ExcelWriter(output, engine='openpyxl')
- st.session_state['filtered_df'].to_excel(writer, index=False)
- writer.close()
- processed_data = output.getvalue()
-
- st.download_button(
- label="Download Current Data",
- data=processed_data,
- file_name='data.xlsx',
- mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
- )
- else:
- st.warning("Please ensure the default dataset exists in the 'input' directory.")
-
- else:
- # Upload custom dataset
- uploaded_file = st.sidebar.file_uploader("Upload your Excel file", type=["xlsx"])
- if uploaded_file is not None:
- df = load_uploaded_dataset(uploaded_file)
- if df is not None:
- st.session_state['df'] = df.copy()
- st.session_state['using_default_dataset'] = False
- st.session_state['uploaded_file_name'] = uploaded_file.name
- st.write("Data preview:")
- st.write(df.head())
- df_cols = df.columns.tolist()
-
- st.subheader("**Select Text Columns for Embedding**")
- text_columns_selected = st.multiselect(
- "Text Columns:",
- df_cols,
- default=df_cols[:1] if df_cols else []
- )
- st.session_state['text_columns'] = text_columns_selected
-
- st.write("**Additional Filters**")
- selected_additional_cols = st.multiselect(
- "Select additional columns from your dataset to use as filters:",
- df_cols,
- default=[]
- )
- st.session_state['additional_filters_selected'] = selected_additional_cols
-
- filtered_df = df.copy()
- for col_name in selected_additional_cols:
- if f'selected_filter_{col_name}' not in st.session_state:
- st.session_state[f'selected_filter_{col_name}'] = []
- unique_vals = sorted(df[col_name].dropna().unique().tolist())
- selected_vals = st.multiselect(
- f"Filter by {col_name}",
- options=unique_vals,
- default=st.session_state[f'selected_filter_{col_name}']
- )
- st.session_state[f'selected_filter_{col_name}'] = selected_vals
- if selected_vals:
- filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
-
- st.session_state['filtered_df'] = filtered_df
- st.write("Filtered Data Preview:")
- st.dataframe(filtered_df.head(), hide_index=True)
- st.write(f"Total number of results: {len(filtered_df)}")
-
- output = io.BytesIO()
- writer = pd.ExcelWriter(output, engine='openpyxl')
- filtered_df.to_excel(writer, index=False)
- writer.close()
- processed_data = output.getvalue()
-
- st.download_button(
- label="Download Filtered Data",
- data=processed_data,
- file_name='filtered_data.xlsx',
- mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
- )
- else:
- st.warning("Failed to load the uploaded dataset.")
- else:
- st.warning("Please upload an Excel file to proceed.")
-
- if 'filtered_df' in st.session_state:
- st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
-
-
- ###############################################################################
- # Preserve active tab across reruns
- ###############################################################################
- if 'active_tab_index' not in st.session_state:
- st.session_state.active_tab_index = 0
-
- tabs_titles = ["Semantic Search", "Clustering", "Summarization", "Chat", "Help"]
- tabs = st.tabs(tabs_titles)
- # We just create these references so we can navigate more easily
- tab_semantic, tab_clustering, tab_summarization, tab_chat, tab_help = tabs
-
- ###############################################################################
- # Tab: Help
- ###############################################################################
- with tab_help:
- st.header("Help")
- st.markdown("""
- ### About SNAP
-
- SNAP allows you to explore, filter, search, cluster, and summarize textual datasets.
-
- **Workflow**:
- 1. **Data Selection (Sidebar)**: Choose the default dataset or upload your own.
- 2. **Filtering**: Set additional filters for your dataset.
- 3. **Select Text Columns**: Which columns to embed.
- 4. **Semantic Search** (Tab): Provide a query and threshold to find relevant documents.
- 5. **Clustering** (Tab): Group documents into topics.
- 6. **Summarization** (Tab): Summarize the clustered documents (with optional references).
-
- ### Troubleshooting
- - If you see no results, try lowering the similarity threshold or removing negative/required keywords.
- - Ensure you have at least one text column selected for embeddings.
- """)
-
- ###############################################################################
- # Tab: Semantic Search
- ###############################################################################
- with tab_semantic:
- st.header("Semantic Search")
- if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
- text_columns = st.session_state.get('text_columns', [])
- if not text_columns:
- st.warning("No text columns selected. Please select at least one column for text embedding.")
- else:
- df_full = st.session_state['df']
- # Load or compute embeddings if necessary
- embeddings, _ = load_or_compute_embeddings(
- df_full,
- st.session_state.get('using_default_dataset', False),
- st.session_state.get('uploaded_file_name'),
- text_columns
- )
-
- if embeddings is not None:
- with st.expander("ℹ️ How Semantic Search Works", expanded=False):
- st.markdown("""
- ### Understanding Semantic Search
-
- Unlike traditional keyword search that looks for exact matches, semantic search understands the meaning and context of your query. Here's how it works:
-
- 1. **Query Processing**:
- - Your search query is converted into a numerical representation (embedding) that captures its meaning
- - Example: Searching for "Climate Smart Villages" will understand the concept, not just the words
- - Related terms like "sustainable communities", "resilient farming", or "agricultural adaptation" might be found even if they don't contain the exact words
-
- 2. **Similarity Matching**:
- - Documents are ranked by how closely their meaning matches your query
- - The similarity threshold controls how strict this matching is
- - Higher threshold (e.g., 0.8) = more precise but fewer results
- - Lower threshold (e.g., 0.3) = more results but might be less relevant
-
- 3. **Advanced Features**:
- - **Negative Keywords**: Use to explicitly exclude documents containing certain terms
- - **Required Keywords**: Ensure specific terms appear in the results
- - These work as traditional keyword filters after the semantic search
-
- ### Search Tips
-
- - **Phrase Queries**: Enter complete phrases for better context
- - "Climate Smart Villages" (as one concept)
- - Better than separate terms: "climate", "smart", "villages"
-
- - **Descriptive Queries**: Add context for better results
- - Instead of: "water"
- - Better: "water management in agriculture"
-
- - **Conceptual Queries**: Focus on concepts rather than specific terms
- - Instead of: "increased yield"
- - Better: "agricultural productivity improvements"
-
- ### Example Searches
-
- 1. **Query**: "Climate Smart Villages"
- - Will find: Documents about climate-resilient communities, adaptive farming practices, sustainable village development
- - Even if they don't use these exact words
-
- 2. **Query**: "Gender equality in agriculture"
- - Will find: Women's empowerment in farming, female farmer initiatives, gender-inclusive rural development
- - Related concepts are captured semantically
-
- 3. **Query**: "Sustainable water management"
- + Required keyword: "irrigation"
- - Combines semantic understanding of water sustainability with specific irrigation focus
- """)
-
- with st.form("search_parameters"):
- query = st.text_input("Enter your search query:")
- include_keywords = st.text_input("Include only documents containing these words (comma-separated):")
- similarity_threshold = st.slider("Similarity threshold", 0.0, 1.0, 0.35)
- submitted = st.form_submit_button("Search")
-
- if submitted:
- if query.strip():
- with st.spinner("Performing Semantic Search..."):
- # Clear any existing summary data when new search is run
- if 'summary_df' in st.session_state:
- del st.session_state['summary_df']
- if 'high_level_summary' in st.session_state:
- del st.session_state['high_level_summary']
- if 'enhanced_summary' in st.session_state:
- del st.session_state['enhanced_summary']
-
- model = get_embedding_model()
- df_filtered = st.session_state['filtered_df'].fillna("")
- search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
-
- # Filter the embeddings to the same subset
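-                                # (Assumes the embeddings array is row-aligned with the full
-                                # DataFrame's default RangeIndex, so these label indices can be
-                                # used as positions.)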
- subset_indices = df_filtered.index
- subset_embeddings = embeddings[subset_indices]
-
- query_embedding = model.encode([query], device=device)
- similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
-
- # Show distribution
- fig = px.histogram(
- x=similarities,
- nbins=30,
- labels={'x': 'Similarity Score', 'y': 'Number of Documents'},
- title='Distribution of Similarity Scores'
- )
- fig.add_vline(
- x=similarity_threshold,
- line_dash="dash",
- line_color="red",
- annotation_text=f"Threshold: {similarity_threshold:.2f}",
- annotation_position="top"
- )
- st.write("### Similarity Score Distribution")
- st.plotly_chart(fig)
-
- above_threshold_indices = np.where(similarities > similarity_threshold)[0]
- if len(above_threshold_indices) == 0:
- st.warning("No results found above the similarity threshold.")
- if 'search_results' in st.session_state:
- del st.session_state['search_results']
- else:
- selected_indices = subset_indices[above_threshold_indices]
- results = df_filtered.loc[selected_indices].copy()
- results['similarity_score'] = similarities[above_threshold_indices]
- results.sort_values(by='similarity_score', ascending=False, inplace=True)
-
- # Include keyword filtering
- if include_keywords.strip():
- inc_words = [w.strip().lower() for w in include_keywords.split(',') if w.strip()]
- if inc_words:
- results = results[
- results.apply(
- lambda row: all(
- w in (' '.join(row.astype(str)).lower()) for w in inc_words
- ),
- axis=1
- )
- ]
-
- if results.empty:
- st.warning("No results found after applying keyword filters.")
- if 'search_results' in st.session_state:
- del st.session_state['search_results']
- else:
- st.session_state['search_results'] = results.copy()
- output = io.BytesIO()
- writer = pd.ExcelWriter(output, engine='openpyxl')
- results.to_excel(writer, index=False)
- writer.close()
- processed_data = output.getvalue()
- st.session_state['search_results_processed_data'] = processed_data
- else:
- st.warning("Please enter a query to search.")
-
- # Display search results if available
- if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
- st.write("## Search Results")
- results = st.session_state['search_results']
- cols_to_display = [c for c in results.columns if c != 'similarity_score'] + ['similarity_score']
- st.dataframe(results[cols_to_display], hide_index=True)
- st.write(f"Total number of results: {len(results)}")
-
- if 'search_results_processed_data' in st.session_state:
- st.download_button(
- label="Download Full Results",
- data=st.session_state['search_results_processed_data'],
- file_name='search_results.xlsx',
- mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
- key='download_search_results'
- )
- else:
- st.info("No search results to display. Enter a query and click 'Search'.")
- else:
- st.warning("No embeddings available because no text columns were chosen.")
- else:
- st.warning("Filtered dataset is empty or not loaded. Please adjust your filters or upload data.")
-
-
- ###############################################################################
- # Tab: Clustering
- ###############################################################################
- with tab_clustering:
- st.header("Clustering")
- if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
- # Add explanation about clustering
- with st.expander("ℹ️ How Clustering Works", expanded=False):
- st.markdown("""
- ### Understanding Document Clustering
-
- Clustering automatically groups similar documents together, helping you discover patterns and themes in your data. Here's how it works:
-
- 1. **Cluster Formation**:
- - Documents are grouped based on their semantic similarity
- - Each cluster represents a distinct theme or topic
- - Documents that are too different from others may remain unclustered (labeled as -1)
- - The "Min Cluster Size" parameter controls how clusters are formed
-
- 2. **Interpreting Results**:
- - Each cluster is assigned a number (e.g., 0, 1, 2...)
- - Cluster -1 contains "outlier" documents that didn't fit well in other clusters
- - The size of each cluster indicates how common that theme is
- - Keywords for each cluster show the main topics/concepts
-
- 3. **Visualizations**:
- - **Intertopic Distance Map**: Shows how clusters relate to each other
- - Closer clusters are more semantically similar
- - Size of circles indicates number of documents
- - Hover to see top terms for each cluster
-
- - **Topic Document Visualization**: Shows individual documents
- - Each point is a document
- - Colors indicate cluster membership
- - Distance between points shows similarity
-
- - **Topic Hierarchy**: Shows how topics are related
- - Tree structure shows topic relationships
- - Parent topics contain broader themes
- - Child topics show more specific sub-themes
-
- ### How to Use Clusters
-
- 1. **Exploration**:
- - Use clusters to discover main themes in your data
- - Look for unexpected groupings that might reveal insights
- - Identify outliers that might need special attention
-
- 2. **Analysis**:
- - Compare cluster sizes to understand theme distribution
- - Examine keywords to understand what defines each cluster
- - Use hierarchy to see how themes are nested
-
- 3. **Practical Applications**:
- - Generate summaries for specific clusters
- - Focus detailed analysis on clusters of interest
- - Use clusters to organize and categorize documents
- - Identify gaps or overlaps in your dataset
-
- ### Tips for Better Results
-
- - **Adjust Min Cluster Size**:
- - Larger values (15-20): Fewer, broader clusters
- - Smaller values (2-5): More specific, smaller clusters
- - Balance between too many small clusters and too few large ones
-
- - **Choose Data Wisely**:
- - Cluster full dataset for overall themes
- - Cluster search results for focused analysis
- - More documents generally give better clusters
-
- - **Interpret with Context**:
- - Consider your domain knowledge
- - Look for patterns across multiple visualizations
- - Use cluster insights to guide further analysis
- """)
-
- df_to_cluster = None
-
- # Create a single form for clustering settings
- with st.form("clustering_form"):
- st.subheader("Clustering Settings")
-
- # Data source selection
- clustering_option = st.radio(
- "Select data for clustering:",
- ('Full Dataset', 'Filtered Dataset', 'Semantic Search Results')
- )
-
- # Clustering parameters
- min_cluster_size_val = st.slider(
- "Min Cluster Size",
- min_value=2,
- max_value=50,
- value=st.session_state.get('min_cluster_size', 5),
-                        help="Minimum size of each cluster in HDBSCAN; in other words, it's the minimum number of documents/texts that must be grouped together to form a valid cluster.\n\n- A larger value (e.g., 20) will result in fewer, larger clusters\n- A smaller value (e.g., 2-5) will allow for more clusters, including smaller ones\n- Documents that don't fit into any cluster meeting this minimum size requirement are labeled as noise (typically assigned to cluster -1)"
- )
-
- run_clustering = st.form_submit_button("Run Clustering")
-
- if run_clustering:
- st.session_state.active_tab_index = tabs_titles.index("Clustering")
- st.session_state['min_cluster_size'] = min_cluster_size_val
-
- # Decide which DataFrame is used based on the selection
- if clustering_option == 'Semantic Search Results':
- if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
- df_to_cluster = st.session_state['search_results'].copy()
- else:
- st.warning("No semantic search results found. Please run a search first.")
- elif clustering_option == 'Filtered Dataset':
- if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
- df_to_cluster = st.session_state['filtered_df'].copy()
- else:
- st.warning("Filtered dataset is empty. Please check your filters.")
- else: # Full Dataset
- if 'df' in st.session_state and not st.session_state['df'].empty:
- df_to_cluster = st.session_state['df'].copy()
-
- text_columns = st.session_state.get('text_columns', [])
- if not text_columns:
- st.warning("No text columns selected. Please select text columns to embed before clustering.")
- else:
- # Ensure embeddings are available
- df_full = st.session_state['df']
- embeddings, _ = load_or_compute_embeddings(
- df_full,
- st.session_state.get('using_default_dataset', False),
- st.session_state.get('uploaded_file_name'),
- text_columns
- )
-
- if df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering:
- with st.spinner("Performing clustering..."):
- # Clear any existing summary data when clustering is run
- if 'summary_df' in st.session_state:
- del st.session_state['summary_df']
- if 'high_level_summary' in st.session_state:
- del st.session_state['high_level_summary']
- if 'enhanced_summary' in st.session_state:
- del st.session_state['enhanced_summary']
-
- dfc = df_to_cluster.copy().fillna("")
- dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
-
- # Filter embeddings to those rows
- selected_indices = dfc.index
- embeddings_clustering = embeddings[selected_indices]
-
- # Basic cleaning
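-                        # The stopword-stripped text feeds BERTopic's keyword extraction
-                        # (c-TF-IDF); the clustering itself uses the precomputed embeddings
-                        # passed to fit_transform below, which were built from the raw text.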
- stop_words = set(stopwords.words('english'))
- texts_cleaned = []
- for text in dfc['text'].tolist():
- try:
- # First try with word_tokenize
- try:
- word_tokens = word_tokenize(text)
- except LookupError:
- # If punkt is missing, try downloading it again
- nltk.download('punkt_tab', quiet=False)
- word_tokens = word_tokenize(text)
- except Exception as e:
- # If word_tokenize fails, fall back to simple splitting
- st.warning(f"Using fallback tokenization due to error: {e}")
- word_tokens = text.split()
-
- filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
- texts_cleaned.append(filtered_text)
- except Exception as e:
- st.error(f"Error processing text: {e}")
- # Add the original text if processing fails
- texts_cleaned.append(text)
-
- try:
- # Validation checks before clustering
- if len(texts_cleaned) < min_cluster_size_val:
- st.error(f"Not enough documents to form clusters. You have {len(texts_cleaned)} documents but minimum cluster size is set to {min_cluster_size_val}.")
- st.session_state['clustering_error'] = "Insufficient documents for clustering"
- st.session_state.active_tab_index = tabs_titles.index("Clustering")
- st.stop()
-
- # Convert embeddings to CPU numpy if needed
- if torch.is_tensor(embeddings_clustering):
- embeddings_for_clustering = embeddings_clustering.cpu().numpy()
- else:
- embeddings_for_clustering = embeddings_clustering
-
- # Additional validation
- if embeddings_for_clustering.shape[0] != len(texts_cleaned):
- st.error("Mismatch between number of embeddings and texts.")
- st.session_state['clustering_error'] = "Embedding and text count mismatch"
- st.session_state.active_tab_index = tabs_titles.index("Clustering")
- st.stop()
-
- # Build the HDBSCAN model with error handling
- try:
- hdbscan_model = HDBSCAN(
- min_cluster_size=min_cluster_size_val,
- metric='euclidean',
- cluster_selection_method='eom'
- )
-
- # Build the BERTopic model
- topic_model = BERTopic(
- embedding_model=get_embedding_model(),
- hdbscan_model=hdbscan_model
- )
-
- # Fit the model and get topics
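-                                # (Passing precomputed embeddings lets BERTopic skip re-embedding
-                                # the documents and cluster directly on them.)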
- topics, probs = topic_model.fit_transform(
- texts_cleaned,
- embeddings=embeddings_for_clustering
- )
-
- # Validate clustering results
- unique_topics = set(topics)
- if len(unique_topics) < 2:
-                                    st.warning("Clustering resulted in too few clusters. Try re-running the clustering or reducing the minimum cluster size.")
- if -1 in unique_topics:
- non_noise_docs = sum(1 for t in topics if t != -1)
- st.info(f"Only {non_noise_docs} documents were assigned to clusters. The rest were marked as noise (-1).")
- if non_noise_docs < min_cluster_size_val:
- st.error("Not enough documents were successfully clustered. Try reducing the minimum cluster size.")
- st.session_state['clustering_error'] = "Insufficient clustered documents"
- st.session_state.active_tab_index = tabs_titles.index("Clustering")
- st.stop()
-
- # Store results if validation passes
- dfc['Topic'] = topics
- st.session_state['topic_model'] = topic_model
- st.session_state['clustered_data'] = dfc.copy()
- st.session_state['clustering_texts_cleaned'] = texts_cleaned
- st.session_state['clustering_embeddings'] = embeddings_for_clustering
- st.session_state['clustering_completed'] = True
-
- # Try to generate visualizations with error handling
- try:
- st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
- except Exception as viz_error:
- st.warning("Could not generate topic visualization. This usually happens when there are too few total clusters. Try adjusting the minimum cluster size or adding more documents.")
- st.session_state['intertopic_distance_fig'] = None
-
- try:
- st.session_state['topic_document_fig'] = topic_model.visualize_documents(
- texts_cleaned,
- embeddings=embeddings_for_clustering
- )
- except Exception as viz_error:
- st.warning("Could not generate document visualization. This might happen when the clustering results are not optimal. Try adjusting the clustering parameters.")
- st.session_state['topic_document_fig'] = None
-
- try:
- hierarchy = topic_model.hierarchical_topics(texts_cleaned)
- st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
- st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
- except Exception as viz_error:
- st.warning("Could not generate topic hierarchy visualization. This usually happens when there aren't enough distinct topics to form a hierarchy.")
- st.session_state['hierarchy'] = pd.DataFrame()
- st.session_state['hierarchy_fig'] = None
-
- except ValueError as ve:
- if "zero-size array to reduction operation maximum which has no identity" in str(ve):
- st.error("Clustering failed: No valid clusters could be formed. Try reducing the minimum cluster size.")
- elif "Cannot use scipy.linalg.eigh for sparse A with k > N" in str(ve):
- st.error("Clustering failed: Too many components requested for the number of documents. Try with more documents or adjust clustering parameters.")
- else:
- st.error(f"Clustering error: {str(ve)}")
- st.session_state['clustering_error'] = str(ve)
- st.session_state.active_tab_index = tabs_titles.index("Clustering")
- st.stop()
-
- except Exception as e:
- st.error(f"An error occurred during clustering: {str(e)}")
- st.session_state['clustering_error'] = str(e)
- st.session_state['clustering_completed'] = False
- st.session_state.active_tab_index = tabs_titles.index("Clustering")
- st.stop()
-
- # Display clustering results if they exist
- if st.session_state.get('clustering_completed', False):
- st.subheader("Topic Overview")
- dfc = st.session_state['clustered_data']
- topic_model = st.session_state['topic_model']
- topics = dfc['Topic'].tolist()
-
- unique_topics = sorted(list(set(topics)))
- cluster_info = []
- for t in unique_topics:
- cluster_docs = dfc[dfc['Topic'] == t]
- count = len(cluster_docs)
- top_words = topic_model.get_topic(t)
- if top_words:
- top_keywords = ", ".join([w[0] for w in top_words[:5]])
- else:
- top_keywords = "N/A"
- cluster_info.append((t, count, top_keywords))
- cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
-
- st.write("### Topic Overview")
- st.dataframe(
- cluster_df,
- column_config={
- "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
- "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
- "Top Keywords": st.column_config.TextColumn(
- "Top Keywords",
- help="Top 5 keywords that characterize this topic"
- )
- },
- hide_index=True
- )
-
- st.subheader("Clustering Results")
- columns_to_display = [c for c in dfc.columns if c != 'text']
- st.dataframe(dfc[columns_to_display], hide_index=True)
-
- # Display stored visualizations with error handling
- st.write("### Intertopic Distance Map")
- if st.session_state.get('intertopic_distance_fig') is not None:
- try:
- st.plotly_chart(st.session_state['intertopic_distance_fig'])
- except Exception:
- st.info("Topic visualization is not available for the current clustering results.")
-
- st.write("### Topic Document Visualization")
- if st.session_state.get('topic_document_fig') is not None:
- try:
- st.plotly_chart(st.session_state['topic_document_fig'])
- except Exception:
- st.info("Document visualization is not available for the current clustering results.")
-
- st.write("### Topic Hierarchy")
- if st.session_state.get('hierarchy_fig') is not None:
- try:
- st.plotly_chart(st.session_state['hierarchy_fig'])
- except Exception:
- st.info("Topic hierarchy visualization is not available for the current clustering results.")
- if not (df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering):
- st.warning("Please select or upload a dataset and filter as needed.")
-
-
- ###############################################################################
- # Tab: Summarization
- ###############################################################################
- with tab_summarization:
- st.header("Summarization")
- # Add explanation about summarization
- with st.expander("ℹ️ How Summarization Works", expanded=False):
- st.markdown("""
- ### Understanding Document Summarization
-
- Summarization condenses multiple documents into concise, meaningful summaries while preserving key information. Here's how it works:
-
- 1. **Summary Generation**:
- - Documents are processed using advanced language models
- - Key themes and important points are identified
- - Content is condensed while maintaining context
- - Both high-level and cluster-specific summaries are available
-
- 2. **Reference System**:
- - Summaries can include references to source documents
- - References are shown as [ID] or as clickable links
- - Each statement can be traced back to its source
- - Helps maintain accountability and verification
-
- 3. **Types of Summaries**:
- - **High-Level Summary**: Overview of all selected documents
- - Captures main themes across the entire selection
- - Ideal for quick understanding of large document sets
- - Shows relationships between different topics
-
- - **Cluster-Specific Summaries**: Focused on each cluster
- - More detailed for specific themes
- - Shows unique aspects of each cluster
- - Helps understand sub-topics in depth
-
- ### How to Use Summaries
-
- 1. **Configuration**:
- - Choose between all clusters or specific ones
- - Set temperature for creativity vs. consistency
- - Adjust max tokens for summary length
- - Enable/disable reference system
-
- 2. **Reference Options**:
- - Select column for reference IDs
- - Add hyperlinks to references
- - Choose URL column for clickable links
- - References help track information sources
-
- 3. **Practical Applications**:
- - Quick overview of large datasets
- - Detailed analysis of specific themes
- - Evidence-based reporting with references
- - Compare different document groups
-
- ### Tips for Better Results
-
- - **Temperature Setting**:
- - Higher (0.7-1.0): More creative, varied summaries
- - Lower (0.1-0.3): More consistent, conservative summaries
- - Balance based on your needs for creativity vs. consistency
-
- - **Token Length**:
- - Longer limits: More detailed summaries
- - Shorter limits: More concise, focused summaries
- - Adjust based on document complexity
-
- - **Reference Usage**:
- - Enable references for traceability
- - Use hyperlinks for easy navigation
- - Choose meaningful reference columns
- - Helps validate summary accuracy
-
- ### Best Practices
-
- 1. **For General Overview**:
- - Use high-level summary
- - Keep temperature moderate (0.5-0.7)
- - Enable references for verification
- - Focus on broader themes
-
- 2. **For Detailed Analysis**:
- - Use cluster-specific summaries
- - Adjust temperature based on need
- - Include references with hyperlinks
- - Look for patterns within clusters
-
- 3. **For Reporting**:
- - Combine both summary types
- - Use references extensively
- - Balance detail and brevity
- - Ensure source traceability
- """)
-
- df_summ = None
- # We'll try to summarize either the clustered data or just the filtered dataset
- if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
- df_summ = st.session_state['clustered_data']
- elif 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
- df_summ = st.session_state['filtered_df']
- else:
- st.warning("No data available for summarization. Please cluster first or have some filtered data.")
-
- if df_summ is not None and not df_summ.empty:
- text_columns = st.session_state.get('text_columns', [])
- if not text_columns:
- st.warning("No text columns selected. Please select columns for text embedding first.")
- else:
- if 'Topic' not in df_summ.columns or 'topic_model' not in st.session_state:
- st.warning("No 'Topic' column found. Summaries per cluster are only available if you've run clustering.")
- else:
- topic_model = st.session_state['topic_model']
- df_summ['text'] = df_summ[text_columns].fillna("").astype(str).agg(' '.join, axis=1)
-
- # List of topics
- topics = sorted(df_summ['Topic'].unique())
- cluster_info = []
- for t in topics:
- cluster_docs = df_summ[df_summ['Topic'] == t]
- count = len(cluster_docs)
- top_words = topic_model.get_topic(t)
- if top_words:
- top_keywords = ", ".join([w[0] for w in top_words[:5]])
- else:
- top_keywords = "N/A"
- cluster_info.append((t, count, top_keywords))
- cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
-
- # If we have cluster names from previous summarization, add them
- if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
- summary_df = st.session_state['summary_df']
- # Create a mapping of topic to name for merging
- topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
- # Add cluster names to cluster_df
- cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
- # Reorder columns to show name after topic
- cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
-
- st.write("### Available Clusters:")
- st.dataframe(
- cluster_df,
- column_config={
- "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
- "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
- "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
- "Top Keywords": st.column_config.TextColumn(
- "Top Keywords",
- help="Top 5 keywords that characterize this topic"
- )
- },
- hide_index=True
- )
-
- # Summarization settings
- st.subheader("Summarization Settings")
- # Summaries scope
- summary_scope = st.radio(
- "Generate summaries for:",
- ["All clusters", "Specific clusters"]
- )
- if summary_scope == "Specific clusters":
- # Format options to include cluster names if available
- if 'Cluster_Name' in cluster_df.columns:
- topic_options = [f"Cluster {t} - {name}" for t, name in zip(cluster_df['Topic'], cluster_df['Cluster_Name'])]
- topic_to_id = {opt: t for opt, t in zip(topic_options, cluster_df['Topic'])}
- selected_topic_options = st.multiselect("Select clusters to summarize", topic_options)
- selected_topics = [topic_to_id[opt] for opt in selected_topic_options]
- else:
- selected_topics = st.multiselect("Select clusters to summarize", topics)
- else:
- selected_topics = topics
-
- # Add system prompt configuration
- default_system_prompt = """You are an expert summarizer skilled in creating concise and relevant summaries.
- You will be given text and an objective context. Please produce a clear, cohesive,
- and thematically relevant summary.
- Focus on key points, insights, or patterns that emerge from the text."""
-
- if 'system_prompt' not in st.session_state:
- st.session_state['system_prompt'] = default_system_prompt
-
- with st.expander("🔧 Advanced Settings", expanded=False):
- st.markdown("""
- ### System Prompt Configuration
-
- The system prompt guides the AI in how to generate summaries. You can customize it to better suit your needs:
- - Be specific about the style and focus you want
- - Add domain-specific context if needed
- - Include any special formatting requirements
- """)
-
- system_prompt = st.text_area(
- "Customize System Prompt",
- value=st.session_state['system_prompt'],
- height=150,
- help="This prompt guides the AI in how to generate summaries. Edit it to customize the summary style and focus."
- )
-
- if st.button("Reset to Default"):
- system_prompt = default_system_prompt
- st.session_state['system_prompt'] = default_system_prompt
-
- st.markdown("### Generation Parameters")
- temperature = st.slider(
- "Temperature",
- 0.0, 1.0, 0.7,
- help="Higher values (0.7-1.0) make summaries more creative but less predictable. Lower values (0.1-0.3) make them more focused and consistent."
- )
- max_tokens = st.slider(
- "Max Tokens",
- 100, 3000, 1000,
- help="Maximum length of generated summaries. Higher values allow for more detailed summaries but take longer to generate."
- )
-
- st.session_state['system_prompt'] = system_prompt
-
- st.write("### Enhanced Summary References")
- st.write("Select columns for references (optional).")
- all_cols = [c for c in df_summ.columns if c not in ['text', 'Topic', 'similarity_score']]
-
- # By default, let's guess the first column as reference ID if available
- if 'reference_id_column' not in st.session_state:
- st.session_state.reference_id_column = all_cols[0] if all_cols else None
- # If there's a column that looks like a URL, guess that
- url_guess = next((c for c in all_cols if 'url' in c.lower() or 'link' in c.lower()), None)
- if 'url_column' not in st.session_state:
- st.session_state.url_column = url_guess
-
- enable_references = st.checkbox(
- "Enable references in summaries",
- value=True, # default to True as requested
- help="Add source references to the final summary text."
- )
- reference_id_column = st.selectbox(
- "Select column to use as reference ID:",
- all_cols,
- index=all_cols.index(st.session_state.reference_id_column) if st.session_state.reference_id_column in all_cols else 0
- )
- add_hyperlinks = st.checkbox(
- "Add hyperlinks to references",
- value=True, # default to True
- help="If the reference column has a matching URL, make it clickable."
- )
- url_column = None
- if add_hyperlinks:
- url_column = st.selectbox(
- "Select column containing URLs:",
- all_cols,
- index=all_cols.index(st.session_state.url_column) if (st.session_state.url_column in all_cols) else 0
- )
-
- # Summarization button
- if st.button("Generate Summaries"):
- openai_api_key = os.environ.get('OPENAI_API_KEY')
- if not openai_api_key:
- st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
- else:
- # Set flag to indicate summarization button was clicked
- st.session_state['_summarization_button_clicked'] = True
-
- llm = ChatOpenAI(
- api_key=openai_api_key,
- model_name='gpt-4o-mini', # or 'gpt-4o'
- temperature=temperature,
- max_tokens=max_tokens
- )
-
- # Filter to selected topics
- if selected_topics:
- df_scope = df_summ[df_summ['Topic'].isin(selected_topics)]
- else:
- st.warning("No topics selected for summarization.")
- df_scope = pd.DataFrame()
-
- if df_scope.empty:
- st.warning("No documents match the selected topics for summarization.")
- else:
- all_texts = df_scope['text'].tolist()
- combined_text = " ".join(all_texts)
- if not combined_text.strip():
- st.warning("No text data available for summarization.")
- else:
- # For cluster-specific summaries, use the customized prompt
- local_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
- local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
- local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
-
- # Summaries per cluster
- # Only if multiple clusters are selected
- unique_selected_topics = df_scope['Topic'].unique()
- if len(unique_selected_topics) > 1:
- st.write("### Summaries per Selected Cluster")
-
- # Process summaries in parallel
- with st.spinner("Generating cluster summaries in parallel..."):
- summaries = process_summaries_in_parallel(
- df_scope=df_scope,
- unique_selected_topics=unique_selected_topics,
- llm=llm,
- chat_prompt=local_chat_prompt,
- enable_references=enable_references,
- reference_id_column=reference_id_column,
- url_column=url_column if add_hyperlinks else None,
- max_workers=min(16, len(unique_selected_topics)) # Limit workers based on clusters
- )
-
- if summaries:
- summary_df = pd.DataFrame(summaries)
- # Store the summaries DataFrame in session state
- st.session_state['summary_df'] = summary_df
- # Store additional summary info in session state
- st.session_state['has_references'] = enable_references
- st.session_state['reference_id_column'] = reference_id_column
- st.session_state['url_column'] = url_column if add_hyperlinks else None
-
- # Update cluster_df with new names
- if 'Cluster_Name' in summary_df.columns:
- topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
- cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
- cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
-
- # Immediately display updated cluster overview
- st.write("### Updated Topic Overview:")
- st.dataframe(
- cluster_df,
- column_config={
- "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
- "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
- "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
- "Top Keywords": st.column_config.TextColumn(
- "Top Keywords",
- help="Top 5 keywords that characterize this topic"
- )
- },
- hide_index=True
- )
-
- # Now generate high-level summary from the cluster summaries
- with st.spinner("Generating high-level summary from cluster summaries..."):
- # Format cluster summaries with proper markdown and HTML
- formatted_summaries = []
- total_tokens = 0
- MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) # Leave room for system prompt and completion
- summary_batches = []
- current_batch = []
- current_batch_tokens = 0
-
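- # Greedy batching: pack cluster summaries into batches that each stay under
- # MAX_SAFE_TOKENS, so every synthesis call fits in the model's context window
- # alongside the system prompt and the completion. Token counts come from the
- # GPT-2 tokenizer loaded at startup, which only approximates the chat model's
- # own tokenizer.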
- for _, row in summary_df.iterrows():
- summary_text = row.get('Enhanced_Summary', row['Summary'])
- formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
- summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
-
- # If adding this summary would exceed the safe token limit, start a new batch
- if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
- if current_batch: # Only append if we have summaries in the current batch
- summary_batches.append(current_batch)
- current_batch = []
- current_batch_tokens = 0
-
- current_batch.append(formatted_summary)
- current_batch_tokens += summary_tokens
-
- # Add the last batch if it has any summaries
- if current_batch:
- summary_batches.append(current_batch)
-
- # Generate overview for each batch
- batch_overviews = []
- with st.spinner("Generating batch summaries..."):
- for i, batch in enumerate(summary_batches, 1):
- st.write(f"Processing batch {i} of {len(summary_batches)}...")
-
- batch_text = "\n\n".join(batch)
- batch_prompt = f"""Below are summaries for a subset of clusters, produced with transformer-based NLP on results from the CGIAR reporting system. Each summary contains references to source documents shown as [ID] or as hyperlinked IDs.
-
-Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
-1. Preserve all hyperlinked references exactly as they appear in the input summaries
-2. Maintain the HTML anchor tags (<a>) intact when using information from the summaries
-3. Keep the markdown formatting for better readability
-4. Note that this is part {i} of {len(summary_batches)} parts, so focus on the themes present in these specific clusters
-
-Here are the cluster summaries to synthesize:
-
-{batch_text}"""
-
- # Generate overview for this batch
- high_level_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
- high_level_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
- high_level_chat_prompt = ChatPromptTemplate.from_messages([high_level_system_message, high_level_human_message])
- high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
- batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
- batch_overviews.append(batch_overview)
-
- # Now combine the batch overviews
- with st.spinner("Generating final combined summary..."):
- combined_overviews = "\n\n".join([f"### Part {i+1}:\n\n{overview}" for i, overview in enumerate(batch_overviews)])
- final_prompt = f"""Below are {len(batch_overviews)} overview summaries, each covering different clusters of research results. Each part maintains its original references to source documents.
-
-Please create a final comprehensive synthesis that:
-1. Integrates the key themes and findings from all parts
-2. Preserves all hyperlinked references exactly as they appear
-3. Maintains the HTML anchor tags (<a>) intact
-4. Keeps the markdown formatting for better readability
-5. Creates a coherent narrative across all parts
-6. Highlights any themes that span multiple parts
-
-Here are the overviews to synthesize:
-
-{combined_overviews}"""
-
- # Verify the final prompt's token count
- final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
- if final_prompt_tokens > MAX_SAFE_TOKENS:
- st.error(f"❌ Final synthesis prompt ({final_prompt_tokens:,} tokens) exceeds safe limit ({MAX_SAFE_TOKENS:,}). Using batch summaries separately.")
- high_level_summary = "# Overall Summary\n\n" + "\n\n".join([f"## Batch {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
- else:
- # Generate final synthesis
- high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
- high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
-
- # Store both versions of the summary
- st.session_state['high_level_summary'] = high_level_summary
- st.session_state['enhanced_summary'] = high_level_summary
-
- # Set flag to indicate summarization is complete
- st.session_state['summarization_completed'] = True
-
- # Update the display without rerunning
- st.write("### High-Level Summary:")
- st.markdown(high_level_summary, unsafe_allow_html=True)
-
- # Display cluster summaries
- st.write("### Cluster Summaries:")
- if enable_references and 'Enhanced_Summary' in summary_df.columns:
- for idx, row in summary_df.iterrows():
- cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
- st.write(f"**Topic {row['Topic']} - {cluster_name}**")
- st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
- st.write("---")
- with st.expander("View original summaries in table format"):
- display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
- display_df.columns = ['Topic', 'Cluster Name', 'Summary']
- st.dataframe(display_df, hide_index=True)
- else:
- st.write("### Summaries per Cluster:")
- if 'Cluster_Name' in summary_df.columns:
- display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
- display_df.columns = ['Topic', 'Cluster Name', 'Summary']
- st.dataframe(display_df, hide_index=True)
- else:
- st.dataframe(summary_df, hide_index=True)
-
- # Download
- if 'Enhanced_Summary' in summary_df.columns:
- dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
- dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
- else:
- dl_df = summary_df
- csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
- b64 = base64.b64encode(csv_bytes).decode()
- href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>'  # data-URI download link; filename is arbitrary
- st.markdown(href, unsafe_allow_html=True)
-
- # Display existing summaries if available and summarization was completed
- if st.session_state.get('summarization_completed', False):
- if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
- if 'high_level_summary' in st.session_state:
- st.write("### High-Level Summary:")
- st.markdown(st.session_state['enhanced_summary'] if st.session_state.get('enhanced_summary') else st.session_state['high_level_summary'], unsafe_allow_html=True)
-
- st.write("### Cluster Summaries:")
- summary_df = st.session_state['summary_df']
- if 'Enhanced_Summary' in summary_df.columns:
- for idx, row in summary_df.iterrows():
- cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
- st.write(f"**Topic {row['Topic']} - {cluster_name}**")
- st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
- st.write("---")
- with st.expander("View original summaries in table format"):
- display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
- display_df.columns = ['Topic', 'Cluster Name', 'Summary']
- st.dataframe(display_df, hide_index=True)
- else:
- st.dataframe(summary_df, hide_index=True)
-
- # Add download button for existing summaries
- dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
- if 'Cluster_Name' in dl_df.columns:
- dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
- csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
- b64 = base64.b64encode(csv_bytes).decode()
- href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>'  # data-URI download link; filename is arbitrary
- st.markdown(href, unsafe_allow_html=True)
- else:
- st.warning("No data available for summarization.")
-
- # Display existing summaries if available (when returning to the tab)
- if not st.session_state.get('_summarization_button_clicked', False): # Only show if not just generated
- if 'high_level_summary' in st.session_state:
- st.write("### Existing High-Level Summary:")
- if st.session_state.get('enhanced_summary'):
- st.markdown(st.session_state['enhanced_summary'], unsafe_allow_html=True)
- with st.expander("View original summary (without references)"):
- st.write(st.session_state['high_level_summary'])
- else:
- st.write(st.session_state['high_level_summary'])
-
- if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
- st.write("### Existing Cluster Summaries:")
- summary_df = st.session_state['summary_df']
- if 'Enhanced_Summary' in summary_df.columns:
- for idx, row in summary_df.iterrows():
- cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
- st.write(f"**Topic {row['Topic']} - {cluster_name}**")
- st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
- st.write("---")
- with st.expander("View original summaries in table format"):
- display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
- display_df.columns = ['Topic', 'Cluster Name', 'Summary']
- st.dataframe(display_df, hide_index=True)
- else:
- st.dataframe(summary_df, hide_index=True)
-
- # Add download button for existing summaries
- dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
- if 'Cluster_Name' in dl_df.columns:
- dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
- csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
- b64 = base64.b64encode(csv_bytes).decode()
- href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>'  # data-URI download link; filename is arbitrary
- st.markdown(href, unsafe_allow_html=True)
-
-
- ###############################################################################
- # Tab: Chat
- ###############################################################################
- with tab_chat:
- st.header("Chat with Your Data")
-
- # Add explanation about chat functionality
- with st.expander("ℹ️ How Chat Works", expanded=False):
- st.markdown("""
- ### Understanding Chat with Your Data
-
- The chat functionality allows you to have an interactive conversation about your data, whether it's filtered, clustered, or raw. Here's how it works:
-
- 1. **Data Selection**:
- - Choose which dataset to chat about (filtered, clustered, or search results)
- - Optionally focus on specific clusters if clustering was performed
- - System automatically includes relevant context from your selection
-
- 2. **Context Window**:
- - Shows how much of the GPT-4o context window is being used
- - Helps you understand if you need to filter data further
- - Displays token usage statistics
-
- 3. **Chat Features**:
- - Ask questions about your data
- - Get insights and analysis
- - Reference specific documents or clusters
- - Download chat context for transparency
-
- ### Best Practices
-
- 1. **Data Selection**:
- - Start with filtered or clustered data for more focused conversations
- - Select specific clusters if you want to dive deep into a topic
- - Consider the context window usage when selecting data
-
- 2. **Asking Questions**:
- - Be specific in your questions
- - Ask about patterns, trends, or insights
- - Reference clusters or documents by their IDs
- - Build on previous questions for deeper analysis
-
- 3. **Managing Context**:
- - Monitor the context window usage
- - Filter data further if context is too full
- - Download chat context for documentation
- - Clear chat history to start fresh
-
- ### Tips for Better Results
-
- - **Question Types**:
- - "What are the main themes in cluster 3?"
- - "Compare the findings between clusters 1 and 2"
- - "Summarize the methodology used across these documents"
- - "What are the common outcomes reported?"
-
- - **Follow-up Questions**:
- - Build on previous answers
- - Ask for clarification
- - Request specific examples
- - Explore relationships between findings
- """)
-
- # Function to check data source availability
- def get_available_data_sources():
- sources = []
- if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
- sources.append("Filtered Dataset")
- if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
- sources.append("Clustered Data")
- if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
- sources.append("Search Results")
- if ('high_level_summary' in st.session_state or
- ('summary_df' in st.session_state and not st.session_state['summary_df'].empty)):
- sources.append("Summarized Data")
- return sources
-
- # Get available data sources
- available_sources = get_available_data_sources()
-
- if not available_sources:
- st.warning("No data available for chat. Please filter, cluster, search, or summarize first.")
- st.stop()
-
- # Initialize or update data source in session state
- if 'chat_data_source' not in st.session_state:
- st.session_state.chat_data_source = available_sources[0]
- elif st.session_state.chat_data_source not in available_sources:
- st.session_state.chat_data_source = available_sources[0]
-
- # Data source selection with automatic fallback
- data_source = st.radio(
- "Select data to chat about:",
- available_sources,
- index=available_sources.index(st.session_state.chat_data_source),
- help="Choose which dataset you want to analyze in the chat."
- )
-
- # Update session state if data source changed
- if data_source != st.session_state.chat_data_source:
- st.session_state.chat_data_source = data_source
- # Clear any cluster-specific selections if switching data sources
- if 'chat_selected_cluster' in st.session_state:
- del st.session_state.chat_selected_cluster
-
- # Get the appropriate DataFrame based on selected source
- df_chat = None
- if data_source == "Filtered Dataset":
- df_chat = st.session_state['filtered_df']
- elif data_source == "Clustered Data":
- df_chat = st.session_state['clustered_data']
- elif data_source == "Search Results":
- df_chat = st.session_state['search_results']
- elif data_source == "Summarized Data":
- # Create DataFrame with selected summaries
- summary_rows = []
-
- # Add high-level summary if available
- if 'high_level_summary' in st.session_state:
- summary_rows.append({
- 'Summary_Type': 'High-Level Summary',
- 'Content': st.session_state.get('enhanced_summary', st.session_state['high_level_summary'])
- })
-
- # Add cluster summaries if available
- if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
- summary_df = st.session_state['summary_df']
- for _, row in summary_df.iterrows():
- summary_rows.append({
- 'Summary_Type': f"Cluster {row['Topic']} Summary",
- 'Content': row.get('Enhanced_Summary', row['Summary'])
- })
-
- if summary_rows:
- df_chat = pd.DataFrame(summary_rows)
-
- if df_chat is not None and not df_chat.empty:
- # If we have clustered data, allow cluster selection
- selected_cluster = None
- if data_source != "Summarized Data" and 'Topic' in df_chat.columns:
- cluster_option = st.radio(
- "Choose cluster scope:",
- ["All Clusters", "Specific Cluster"]
- )
- if cluster_option == "Specific Cluster":
- unique_topics = sorted(df_chat['Topic'].unique())
- # Check if we have cluster names
- if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
- summary_df = st.session_state['summary_df']
- # Create a mapping of topic to name
- topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
- # Format the selectbox options
- topic_options = [
- (t, f"Cluster {t} - {topic_names.get(t, 'Unnamed Cluster')}")
- for t in unique_topics
- ]
- selected_cluster = st.selectbox(
- "Select cluster to focus on:",
- [t[0] for t in topic_options],
- format_func=lambda x: next(opt[1] for opt in topic_options if opt[0] == x)
- )
- else:
- selected_cluster = st.selectbox(
- "Select cluster to focus on:",
- unique_topics,
- format_func=lambda x: f"Cluster {x}"
- )
- if selected_cluster is not None:
- df_chat = df_chat[df_chat['Topic'] == selected_cluster]
- st.session_state.chat_selected_cluster = selected_cluster
- elif 'chat_selected_cluster' in st.session_state:
- del st.session_state.chat_selected_cluster
-
- # Prepare the data for chat context
- text_columns = st.session_state.get('text_columns', [])
- if not text_columns and data_source != "Summarized Data":
- st.warning("No text columns selected. Please select text columns to enable chat functionality.")
- st.stop()
-
- # Instead of limiting to 210 documents, we'll limit by tokens
- MAX_ALLOWED_TOKENS = int(MAX_CONTEXT_WINDOW * 0.95) # 95% of context window
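- # Token budgeting: the system message is counted first, then documents are
- # appended row by row until the remaining budget is used up; rows that do not
- # fit are simply dropped (counts are approximate, via the GPT-2 tokenizer).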
-
- # Prepare system message first to account for its tokens
- system_msg = {
- "role": "system",
- "content": """You are a specialized assistant analyzing data from a research database.
- Your role is to:
- 1. Provide clear, concise answers based on the data provided
- 2. Highlight relevant information from specific results when answering
- 3. When referencing specific results, use their row index or ID if available
- 4. Clearly state if information is not available in the results
- 5. Maintain a professional and analytical tone
- 6. Format your responses using Markdown:
- - Use **bold** for emphasis
- - Use bullet points and numbered lists for structured information
- - Create tables using Markdown syntax when presenting structured data
- - Use backticks for code or technical terms
- - Include hyperlinks when referencing external sources
- - Use headings (###) to organize long responses
-
- The data is provided in a structured format where:""" + ("""
- - Each result contains multiple fields
- - Text content is primarily in the following columns: """ + ", ".join(text_columns) + """
- - Additional metadata and fields are available for reference
- - If clusters are present, they are numbered (e.g., Cluster 0, Cluster 1, etc.)""" if data_source != "Summarized Data" else """
- - The data consists of AI-generated summaries of the documents
- - Each summary may contain references to source documents in markdown format
- - References are shown as [ID] or as clickable hyperlinks
- - Summaries may be high-level (covering all documents) or cluster-specific""") + """
- """
- }
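- # The system prompt above is assembled with a conditional expression: the
- # column/cluster description is used for regular data sources, and the
- # summary-oriented description when "Summarized Data" is selected.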
-
- # Calculate system message tokens
- system_tokens = len(tokenizer(system_msg["content"])["input_ids"])
- remaining_tokens = MAX_ALLOWED_TOKENS - system_tokens
-
- # Prepare the data context with token limiting
- data_text = "Available Data:\n"
- included_rows = 0
- total_rows = len(df_chat)
-
- if data_source == "Summarized Data":
- # For summarized data, process row by row
- for idx, row in df_chat.iterrows():
- row_text = f"\n{row['Summary_Type']}:\n{row['Content']}\n"
- row_tokens = len(tokenizer(row_text)["input_ids"])
-
- if remaining_tokens - row_tokens > 0:
- data_text += row_text
- remaining_tokens -= row_tokens
- included_rows += 1
- else:
- break
- else:
- # For regular data, process row by row
- for idx, row in df_chat.iterrows():
- row_text = f"\nItem {idx}:\n"
- for col in df_chat.columns:
- if not pd.isna(row[col]) and str(row[col]).strip() and col != 'similarity_score':
- row_text += f"{col}: {row[col]}\n"
-
- row_tokens = len(tokenizer(row_text)["input_ids"])
- if remaining_tokens - row_tokens > 0:
- data_text += row_text
- remaining_tokens -= row_tokens
- included_rows += 1
- else:
- break
-
- # Calculate token usage
- data_tokens = len(tokenizer(data_text)["input_ids"])
- total_tokens = system_tokens + data_tokens
- context_usage_percent = (total_tokens / MAX_CONTEXT_WINDOW) * 100
-
- # Display token usage and data coverage
- st.subheader("Context Window Usage")
- st.write(f"System Message: {system_tokens:,} tokens")
- st.write(f"Data Context: {data_tokens:,} tokens")
- st.write(f"Total: {total_tokens:,} tokens ({context_usage_percent:.1f}% of available context)")
- st.write(f"Documents included: {included_rows:,} out of {total_rows:,} ({(included_rows/total_rows*100):.1f}%)")
-
- if context_usage_percent > 90:
- st.warning("⚠️ High context usage! Consider reducing the number of results or filtering further.")
- elif context_usage_percent > 75:
- st.info("ℹ️ Moderate context usage. Still room for your question, but consider reducing results if asking a long question.")
-
- # Add download button for chat context
- chat_context = f"""System Message:
- {system_msg['content']}
-
- {data_text}"""
- st.download_button(
- label="📥 Download Chat Context",
- data=chat_context,
- file_name="chat_context.txt",
- mime="text/plain",
- help="Download the exact context that the chatbot receives"
- )
-
- # Chat interface
- col_chat1, col_chat2 = st.columns([3, 1])
- with col_chat1:
- user_input = st.text_area("Ask a question about your data:", key="chat_input")
- with col_chat2:
- if st.button("Clear Chat History"):
- st.session_state.chat_history = []
- st.rerun()
-
- # Store current tab index before processing
- current_tab = tabs_titles.index("Chat")
-
- if st.button("Send", key="send_button"):
- if user_input:
- # Set the active tab index to stay on Chat
- st.session_state.active_tab_index = current_tab
-
- with st.spinner("Processing your question..."):
- # Add user's question to chat history
- st.session_state.chat_history.append({"role": "user", "content": user_input})
-
- # Prepare messages for API call
- messages = [system_msg]
- messages.append({"role": "user", "content": f"Here is the data to reference:\n\n{data_text}\n\nUser question: {user_input}"})
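- # Note: only the system message and the current question (bundled with the
- # data context) are sent to the API; earlier chat history is displayed below
- # but not included in the request.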
-
- # Get response from OpenAI
- response = get_chat_response(messages)
-
- if response:
- st.session_state.chat_history.append({"role": "assistant", "content": response})
-
- # Display chat history
- st.subheader("Chat History")
- for message in st.session_state.chat_history:
- if message["role"] == "user":
- st.write("**You:**", message["content"])
- else:
- st.write("**Assistant:**")
- st.markdown(message["content"], unsafe_allow_html=True)
- st.write("---") # Add a separator between messages
-
-
- ###############################################################################
- # Tab: Internal Validation
- ###############################################################################
-
-else: # Simple view
- st.header("Automatic Mode")
-
- # Initialize session state for automatic view
- if 'df' not in st.session_state:
- default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
- df = load_default_dataset(default_dataset_path)
- if df is not None:
- st.session_state['df'] = df.copy()
- st.session_state['using_default_dataset'] = True
- st.session_state['filtered_df'] = df.copy()
-
- # Set default text columns if not already set
- if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
- default_text_cols = []
- if 'Title' in df.columns and 'Description' in df.columns:
- default_text_cols = ['Title', 'Description']
- st.session_state['text_columns'] = default_text_cols
-
- # Single search bar for automatic processing
- #st.write("Enter your query to automatically search, cluster, and summarize the results:")
- query = st.text_input("Write your query here:")
-
- if st.button("SNAP!"):
- if query.strip():
- # Step 1: Semantic Search
- st.write("### Step 1: Semantic Search")
- with st.spinner("Performing Semantic Search..."):
- text_columns = st.session_state.get('text_columns', [])
- if text_columns:
- df_full = st.session_state['df']
- embeddings, _ = load_or_compute_embeddings(
- df_full,
- st.session_state.get('using_default_dataset', False),
- st.session_state.get('uploaded_file_name'),
- text_columns
- )
-
- if embeddings is not None:
- model = get_embedding_model()
- df_filtered = st.session_state['filtered_df'].fillna("")
- search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
-
- subset_indices = df_filtered.index
- subset_embeddings = embeddings[subset_indices]
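- # Assumes the embeddings array is aligned with the full dataframe, so the
- # filtered rows' index labels can be used directly as positions into it.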
-
- query_embedding = model.encode([query], device=device)
- similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
-
- similarity_threshold = 0.35 # Default threshold
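- # Only documents whose cosine similarity to the query exceeds this threshold
- # are kept; raise it for stricter matches, lower it for broader recall.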
- above_threshold_indices = np.where(similarities > similarity_threshold)[0]
-
- if len(above_threshold_indices) > 0:
- selected_indices = subset_indices[above_threshold_indices]
- results = df_filtered.loc[selected_indices].copy()
- results['similarity_score'] = similarities[above_threshold_indices]
- results.sort_values(by='similarity_score', ascending=False, inplace=True)
- st.session_state['search_results'] = results.copy()
- st.write(f"Found {len(results)} relevant documents")
- else:
- st.warning("No results found above the similarity threshold.")
- st.stop()
-
- # Step 2: Clustering
- if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
- st.write("### Step 2: Clustering")
- with st.spinner("Performing clustering..."):
- df_to_cluster = st.session_state['search_results'].copy()
- dfc = df_to_cluster.copy().fillna("")
- dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
-
- # Filter embeddings to those rows
- selected_indices = dfc.index
- embeddings_clustering = embeddings[selected_indices]
-
- # Basic cleaning
- stop_words = set(stopwords.words('english'))
- texts_cleaned = []
- for text in dfc['text'].tolist():
- try:
- word_tokens = word_tokenize(text)
- filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
- texts_cleaned.append(filtered_text)
- except Exception as e:
- texts_cleaned.append(text)
-
- min_cluster_size = 5 # Default value
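- # min_cluster_size is the smallest group of documents HDBSCAN will treat as a
- # cluster; documents that fit nowhere end up in the outlier topic -1.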
-
- try:
- # Convert embeddings to CPU numpy if needed
- if torch.is_tensor(embeddings_clustering):
- embeddings_for_clustering = embeddings_clustering.cpu().numpy()
- else:
- embeddings_for_clustering = embeddings_clustering
-
- # Build the HDBSCAN model
- hdbscan_model = HDBSCAN(
- min_cluster_size=min_cluster_size,
- metric='euclidean',
- cluster_selection_method='eom'
- )
-
- # Build the BERTopic model
- topic_model = BERTopic(
- embedding_model=get_embedding_model(),
- hdbscan_model=hdbscan_model
- )
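- # Passing the precomputed embeddings to fit_transform below lets BERTopic skip
- # re-encoding the documents with the sentence-transformer model.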
-
- # Fit the model and get topics
- topics, probs = topic_model.fit_transform(
- texts_cleaned,
- embeddings=embeddings_for_clustering
- )
-
- # Store results
- dfc['Topic'] = topics
- st.session_state['topic_model'] = topic_model
- st.session_state['clustered_data'] = dfc.copy()
- st.session_state['clustering_completed'] = True
-
- # Display clustering results summary
- unique_topics = sorted(list(set(topics)))
- num_clusters = len([t for t in unique_topics if t != -1]) # Exclude noise cluster (-1)
- noise_docs = len([t for t in topics if t == -1])
- clustered_docs = len(topics) - noise_docs
-
- st.write(f"Found {num_clusters} distinct clusters")
- #st.write(f"Documents successfully clustered: {clustered_docs}")
- #if noise_docs > 0:
- # st.write(f"Documents not fitting in any cluster: {noise_docs}")
-
- # Show quick cluster overview
- cluster_info = []
- for t in unique_topics:
- if t != -1: # Skip noise cluster in the overview
- cluster_docs = dfc[dfc['Topic'] == t]
- count = len(cluster_docs)
- top_words = topic_model.get_topic(t)
- top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
- cluster_info.append((t, count, top_keywords))
-
- if cluster_info:
- #st.write("### Quick Cluster Overview:")
- cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
- # st.dataframe(
- # cluster_df,
- # column_config={
- # "Topic": st.column_config.NumberColumn("Topic", help="Topic ID"),
- # "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
- # "Top Keywords": st.column_config.TextColumn(
- # "Top Keywords",
- # help="Top 5 keywords that characterize this topic"
- # )
- # },
- # hide_index=True
- # )
-
- # Generate visualizations
- try:
- st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
- except Exception:
- st.session_state['intertopic_distance_fig'] = None
-
- try:
- st.session_state['topic_document_fig'] = topic_model.visualize_documents(
- texts_cleaned,
- embeddings=embeddings_for_clustering
- )
- except Exception:
- st.session_state['topic_document_fig'] = None
-
- try:
- hierarchy = topic_model.hierarchical_topics(texts_cleaned)
- st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
- st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
- except Exception:
- st.session_state['hierarchy'] = pd.DataFrame()
- st.session_state['hierarchy_fig'] = None
-
- except Exception as e:
- st.error(f"An error occurred during clustering: {str(e)}")
- st.stop()
-
- # Step 3: Summarization
- if st.session_state.get('clustering_completed', False):
- st.write("### Step 3: Summarization")
-
- # Initialize OpenAI client
- openai_api_key = os.environ.get('OPENAI_API_KEY')
- if not openai_api_key:
- st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
- st.stop()
-
- llm = ChatOpenAI(
- api_key=openai_api_key,
- model_name='gpt-4o-mini',
- temperature=0.7,
- max_tokens=1000
- )
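- # Automatic mode uses fixed generation settings (temperature 0.7, 1000 tokens);
- # the advanced view exposes these as sliders in the Summarization tab.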
-
- df_scope = st.session_state['clustered_data']
- unique_selected_topics = df_scope['Topic'].unique()
-
- # Process summaries in parallel
- with st.spinner("Generating summaries..."):
- local_system_message = SystemMessagePromptTemplate.from_template("""You are an expert summarizer skilled in creating concise and relevant summaries.
-You will be given text and an objective context. Please produce a clear, cohesive,
-and thematically relevant summary.
-Focus on key points, insights, or patterns that emerge from the text.""")
- local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
- local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
-
- # Find URL column if it exists
- url_column = next((col for col in df_scope.columns if 'url' in col.lower() or 'link' in col.lower() or 'pdf' in col.lower()), None)
-
- summaries = process_summaries_in_parallel(
- df_scope=df_scope,
- unique_selected_topics=unique_selected_topics,
- llm=llm,
- chat_prompt=local_chat_prompt,
- enable_references=True,
- reference_id_column=df_scope.columns[0],
- url_column=url_column, # Add URL column for clickable links
- max_workers=min(16, len(unique_selected_topics))
- )
-
- if summaries:
- summary_df = pd.DataFrame(summaries)
- st.session_state['summary_df'] = summary_df
-
- # Display updated cluster overview
- if 'Cluster_Name' in summary_df.columns:
- st.write("### Updated Topic Overview:")
- cluster_info = []
- for t in unique_selected_topics:
- cluster_docs = df_scope[df_scope['Topic'] == t]
- count = len(cluster_docs)
- top_words = topic_model.get_topic(t)
- top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
- cluster_name = summary_df[summary_df['Topic'] == t]['Cluster_Name'].iloc[0]
- cluster_info.append((t, cluster_name, count, top_keywords))
-
- cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Cluster_Name", "Count", "Top Keywords"])
- st.dataframe(
- cluster_df,
- column_config={
- "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
- "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
- "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
- "Top Keywords": st.column_config.TextColumn(
- "Top Keywords",
- help="Top 5 keywords that characterize this topic"
- )
- },
- hide_index=True
- )
-
- # Generate and display high-level summary
- with st.spinner("Generating high-level summary..."):
- formatted_summaries = []
- summary_batches = []
- current_batch = []
- current_batch_tokens = 0
- MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75)
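- # Same greedy batching as in the Summarization tab: keep each batch of cluster
- # summaries under MAX_SAFE_TOKENS per synthesis call.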
-
- for _, row in summary_df.iterrows():
- summary_text = row.get('Enhanced_Summary', row['Summary'])
- formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
- summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
-
- if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
- if current_batch:
- summary_batches.append(current_batch)
- current_batch = []
- current_batch_tokens = 0
-
- current_batch.append(formatted_summary)
- current_batch_tokens += summary_tokens
-
- if current_batch:
- summary_batches.append(current_batch)
-
- # Process each batch separately first
- batch_overviews = []
- for i, batch in enumerate(summary_batches, 1):
- st.write(f"Processing summary batch {i} of {len(summary_batches)}...")
- batch_text = "\n\n".join(batch)
- batch_prompt = f"""Below are summaries for a subset of clusters, produced with transformer-based NLP on results from the CGIAR reporting system. Each summary contains references to source documents shown as [ID] or as hyperlinked IDs.
-
-Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
-1. Preserve all hyperlinked references exactly as they appear in the input summaries
-2. Maintain the HTML anchor tags (<a>) intact when using information from the summaries
-3. Keep the markdown formatting for better readability
-4. Create clear sections with headings for different themes
-5. Use bullet points or numbered lists where appropriate
-6. Focus on synthesizing the main themes and findings
-
-Here are the cluster summaries to synthesize:
-
-{batch_text}"""
-
- high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
- batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
- batch_overviews.append(batch_overview)
-
- # Now create the final synthesis
- if len(batch_overviews) > 1:
- st.write("Generating final synthesis...")
- combined_overviews = "\n\n".join([f"# Part {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
- final_prompt = f"""Below are multiple overview summaries, each covering different aspects of CGIAR research results. Each part maintains its original references to source documents.
-
-Please create a final comprehensive synthesis that:
-1. Integrates the key themes and findings from all parts into a cohesive narrative
-2. Preserves all hyperlinked references exactly as they appear
-3. Maintains the HTML anchor tags (<a>) intact
-4. Uses clear section headings and structured formatting
-5. Highlights cross-cutting themes and relationships between different aspects
-6. Provides a clear introduction and conclusion
-
-Here are the overviews to synthesize:
-
-{combined_overviews}"""
-
- final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
- if final_prompt_tokens > MAX_SAFE_TOKENS:
- # If too long, just combine with headers
- high_level_summary = "# Comprehensive Overview\n\n" + "\n\n".join([f"# Part {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
- else:
- high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
- high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
- else:
- # If only one batch, use its overview directly
- high_level_summary = batch_overviews[0]
-
- st.session_state['high_level_summary'] = high_level_summary
- st.session_state['enhanced_summary'] = high_level_summary
-
- # Display summaries
- st.write("### High-Level Summary:")
- with st.expander("High-Level Summary", expanded=True):
- st.markdown(high_level_summary, unsafe_allow_html=True)
-
- st.write("### Cluster Summaries:")
- for idx, row in summary_df.iterrows():
- cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
- with st.expander(f"Topic {row['Topic']} - {cluster_name}", expanded=False):
- st.markdown(row.get('Enhanced_Summary', row['Summary']), unsafe_allow_html=True)
- st.markdown("##### About this tool")
- with st.expander("Click to expand/collapse", expanded=True):
- st.markdown("""
- This tool draws on CGIAR quality-assured results data from 2022-2024 to provide verifiable responses to user questions about the themes and areas CGIAR has worked on and is working on.
-
- **Tips:**
- - **Craft a phrase** that describes your topic of interest (e.g., `"climate-smart agriculture"`, `"gender equality livestock"`).
- - Avoid writing full questions — **this is not a chatbot**.
- - Combine **related terms** for better results (e.g., `"irrigation water access smallholders"`).
- - Focus on **concepts or themes** — not single words like `"climate"` or `"yield"` alone.
- - Example good queries:
- - `"climate adaptation smallholder farming"`
- - `"digital agriculture innovations"`
- - `"nutrition-sensitive value chains"`
-
- **Example use case**:
- You're interested in CGIAR's contributions to **poverty reduction through improved maize varieties in Africa**.
- A good search phrase would be:
- 👉 `"poverty reduction maize Africa"`
- This will retrieve results related to improved crop varieties, livelihood outcomes, and region-specific interventions, even if the documents use different wording like *"enhanced maize genetics"*, *"smallholder income"*, or *"eastern Africa trials"*.
+# app.py
+
+import streamlit as st
+
+# Set page config first, before any other st commands
+st.set_page_config(page_title="SNAP", layout="wide")
+
+# Add warning filters
+import warnings
+# More specific warning filters for torch.classes
+warnings.filterwarnings('ignore', message='.*torch.classes.*__path__._path.*')
+warnings.filterwarnings('ignore', message='.*torch.classes.*registered via torch::class_.*')
+
+import pandas as pd
+import numpy as np
+import os
+import io
+import time
+from datetime import datetime
+import base64
+import re
+import pickle
+from typing import List, Dict, Any, Tuple
+import plotly.express as px
+import torch
+
+# For parallelism
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+# Import necessary libraries for embeddings, clustering, and summarization
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+from bertopic import BERTopic
+from hdbscan import HDBSCAN
+import nltk
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+
+# For summarization and chat
+from langchain.chains import LLMChain
+from langchain_community.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
+from openai import OpenAI
+from transformers import GPT2TokenizerFast
+
+# Initialize OpenAI client and tokenizer
+client = OpenAI()
+
+###############################################################################
+# Helper: Attempt to get this file's directory or fallback to current working dir
+###############################################################################
+def get_base_dir():
+ try:
+ base_dir = os.path.dirname(__file__)
+ if not base_dir:
+ return os.getcwd()
+ return base_dir
+ except NameError:
+ # In case __file__ is not defined (some environments)
+ return os.getcwd()
+
+BASE_DIR = get_base_dir()
+
+# Function to get or create model directory
+def get_model_dir():
+ base_dir = get_base_dir()
+ model_dir = os.path.join(base_dir, 'models')
+ os.makedirs(model_dir, exist_ok=True)
+ return model_dir
+
+# Function to load tokenizer from local storage or download
+def load_tokenizer():
+ model_dir = get_model_dir()
+ tokenizer_dir = os.path.join(model_dir, 'tokenizer')
+ os.makedirs(tokenizer_dir, exist_ok=True)
+
+ try:
+ # Try to load from local directory first
+ tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir)
+ #st.success("Loaded tokenizer from local storage")
+ except Exception as e:
+ #st.warning("Downloading tokenizer (one-time operation)...")
+ try:
+ # Download and save to local directory
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") # Use standard GPT2 tokenizer
+ tokenizer.save_pretrained(tokenizer_dir)
+ #st.success("Downloaded and saved tokenizer")
+ except Exception as download_e:
+ #st.error(f"Error downloading tokenizer: {str(download_e)}")
+ raise
+
+ return tokenizer
+
+# Load tokenizer
+try:
+ tokenizer = load_tokenizer()
+except Exception as e:
+ #st.error("Failed to load tokenizer. Some functionality may be limited.")
+ tokenizer = None
+
+MAX_CONTEXT_WINDOW = 128000 # GPT-4o context window size
+
+# Initialize chat history in session state if not exists
+if 'chat_history' not in st.session_state:
+ st.session_state.chat_history = []
+
+###############################################################################
+# Helper: Get chat response from OpenAI
+###############################################################################
+def get_chat_response(messages):
+ try:
+ response = client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=messages,
+ temperature=0,
+ )
+ return response.choices[0].message.content.strip()
+ except Exception as e:
+ st.error(f"Error querying OpenAI: {e}")
+ return None
+
+###############################################################################
+# Helper: Generate raw summary for a cluster (without references)
+###############################################################################
+def generate_raw_cluster_summary(
+ topic_val: int,
+ cluster_df: pd.DataFrame,
+ llm: Any,
+ chat_prompt: Any
+) -> Dict[str, Any]:
+ """Generate a summary for a single cluster without reference enhancement,
+ automatically trimming text if it exceeds a safe token limit."""
+ cluster_text = " ".join(cluster_df['text'].tolist())
+ if not cluster_text.strip():
+ return None
+
+ # Define a safe limit (95% of max context window to leave room for prompts)
+ safe_limit = int(MAX_CONTEXT_WINDOW * 0.95)
+
+ # Encode the text into tokens
+ encoded_text = tokenizer.encode(cluster_text, add_special_tokens=False)
+
+ # If the text is too large, slice it
+ if len(encoded_text) > safe_limit:
+ #st.warning(f"Cluster {topic_val} text is too large ({len(encoded_text)} tokens). Trimming to {safe_limit} tokens.")
+ encoded_text = encoded_text[:safe_limit]
+ cluster_text = tokenizer.decode(encoded_text)
+
+ user_prompt_local = f"**Text to summarize**: {cluster_text}"
+ try:
+ local_chain = LLMChain(llm=llm, prompt=chat_prompt)
+ summary_local = local_chain.run(user_prompt=user_prompt_local).strip()
+ return {'Topic': topic_val, 'Summary': summary_local}
+ except Exception as e:
+ st.error(f"Error generating summary for cluster {topic_val}: {str(e)}")
+ return None
+
+###############################################################################
+# Helper: Enhance a summary with references
+###############################################################################
+def enhance_summary_with_references(
+ summary_dict: Dict[str, Any],
+ df_scope: pd.DataFrame,
+ reference_id_column: str,
+ url_column: str = None,
+ llm: Any = None
+) -> Dict[str, Any]:
+ """Add references to a summary."""
+ if not summary_dict or 'Summary' not in summary_dict:
+ return summary_dict
+
+ try:
+ cluster_df = df_scope[df_scope['Topic'] == summary_dict['Topic']]
+ enhanced = add_references_to_summary(
+ summary_dict['Summary'],
+ cluster_df,
+ reference_id_column,
+ url_column,
+ llm
+ )
+ summary_dict['Enhanced_Summary'] = enhanced
+ return summary_dict
+ except Exception as e:
+ st.error(f"Error enhancing summary for cluster {summary_dict.get('Topic')}: {str(e)}")
+ return summary_dict
+
+###############################################################################
+# Helper: Process summaries in parallel
+###############################################################################
+def process_summaries_in_parallel(
+ df_scope: pd.DataFrame,
+ unique_selected_topics: List[int],
+ llm: Any,
+ chat_prompt: Any,
+ enable_references: bool = False,
+ reference_id_column: str = None,
+ url_column: str = None,
+ max_workers: int = 16
+) -> List[Dict[str, Any]]:
+ """Process multiple cluster summaries in parallel using ThreadPoolExecutor."""
+ summaries = []
+ total_topics = len(unique_selected_topics)
+
+ # Create progress placeholders
+ progress_text = st.empty()
+ progress_bar = st.progress(0)
+
+ try:
+ # Phase 1: Generate raw summaries in parallel
+ progress_text.text(f"Phase 1/3: Generating cluster summaries in parallel (0/{total_topics} completed)")
+ completed_summaries = 0
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit summary generation tasks
+ future_to_topic = {
+ executor.submit(
+ generate_raw_cluster_summary,
+ topic_val,
+ df_scope[df_scope['Topic'] == topic_val],
+ llm,
+ chat_prompt
+ ): topic_val
+ for topic_val in unique_selected_topics
+ }
+
+ # Process completed summary tasks
+ for future in future_to_topic:
+ try:
+ result = future.result()
+ if result:
+ summaries.append(result)
+ completed_summaries += 1
+ # Update progress
+ progress = completed_summaries / total_topics
+ progress_bar.progress(progress)
+ progress_text.text(
+ f"Phase 1/3: Generating cluster summaries in parallel ({completed_summaries}/{total_topics} completed)"
+ )
+ except Exception as e:
+ topic_val = future_to_topic[future]
+ st.error(f"Error in summary generation for cluster {topic_val}: {str(e)}")
+ completed_summaries += 1
+ continue
+
+ # Phase 2: Enhance summaries with references in parallel (if enabled)
+ if enable_references and reference_id_column and summaries:
+ total_to_enhance = len(summaries)
+ completed_enhancements = 0
+ progress_text.text(f"Phase 2/3: Adding references to summaries (0/{total_to_enhance} completed)")
+ progress_bar.progress(0)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit reference enhancement tasks
+ future_to_summary = {
+ executor.submit(
+ enhance_summary_with_references,
+ summary_dict,
+ df_scope,
+ reference_id_column,
+ url_column,
+ llm
+ ): summary_dict.get('Topic')
+ for summary_dict in summaries
+ }
+
+ # Process completed enhancement tasks
+ enhanced_summaries = []
+ for future in future_to_summary:
+ try:
+ result = future.result()
+ if result:
+ enhanced_summaries.append(result)
+ completed_enhancements += 1
+ # Update progress
+ progress = completed_enhancements / total_to_enhance
+ progress_bar.progress(progress)
+ progress_text.text(
+ f"Phase 2/3: Adding references to summaries ({completed_enhancements}/{total_to_enhance} completed)"
+ )
+ except Exception as e:
+ topic_val = future_to_summary[future]
+ st.error(f"Error in reference enhancement for cluster {topic_val}: {str(e)}")
+ completed_enhancements += 1
+ continue
+
+ summaries = enhanced_summaries
+
+ # Phase 3: Generate cluster names in parallel
+ if summaries:
+ total_to_name = len(summaries)
+ completed_names = 0
+ progress_text.text(f"Phase 3/3: Generating cluster names (0/{total_to_name} completed)")
+ progress_bar.progress(0)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit cluster naming tasks
+ future_to_summary = {
+ executor.submit(
+ generate_cluster_name,
+ summary_dict.get('Enhanced_Summary', summary_dict['Summary']),
+ llm
+ ): summary_dict.get('Topic')
+ for summary_dict in summaries
+ }
+
+ # Process completed naming tasks
+ named_summaries = []
+ for future in future_to_summary:
+ try:
+ cluster_name = future.result()
+ topic_val = future_to_summary[future]
+ # Find the corresponding summary dict
+ summary_dict = next(s for s in summaries if s['Topic'] == topic_val)
+ summary_dict['Cluster_Name'] = cluster_name
+ named_summaries.append(summary_dict)
+ completed_names += 1
+ # Update progress
+ progress = completed_names / total_to_name
+ progress_bar.progress(progress)
+ progress_text.text(
+ f"Phase 3/3: Generating cluster names ({completed_names}/{total_to_name} completed)"
+ )
+ except Exception as e:
+ topic_val = future_to_summary[future]
+ st.error(f"Error in cluster naming for cluster {topic_val}: {str(e)}")
+ completed_names += 1
+ continue
+
+ summaries = named_summaries
+ finally:
+ # Clean up progress indicators
+ progress_text.empty()
+ progress_bar.empty()
+
+ return summaries
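+
+# Illustrative call (a sketch, not executed here): assumes `llm` and `chat_prompt`
+# are built elsewhere in the app, e.g. a ChatOpenAI instance and a ChatPromptTemplate
+# exposing a {user_prompt} variable; the topic list and the "Result code" reference
+# column below are hypothetical.
+#
+#   summaries = process_summaries_in_parallel(
+#       df_scope=clustered_df,
+#       unique_selected_topics=[0, 1, 2],
+#       llm=llm,
+#       chat_prompt=chat_prompt,
+#       enable_references=True,
+#       reference_id_column="Result code",
+#       max_workers=8,
+#   )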
+
+###############################################################################
+# Helper: Generate cluster name
+###############################################################################
+def generate_cluster_name(summary_text: str, llm: Any) -> str:
+ """Generate a concise, descriptive name for a cluster based on its summary."""
+ system_prompt = """You are a cluster naming expert. Your task is to generate a very concise (3-6 words) but descriptive name for a cluster based on its summary. The name should capture the main theme or focus of the cluster.
+
+Rules:
+1. Keep it between 3-6 words
+2. Be specific but concise
+3. Capture the main theme/focus
+4. Use title case
+5. Do not include words like "Cluster", "Topic", or "Theme"
+6. Focus on the content, not metadata
+
+Example good names:
+- Agricultural Water Management Innovation
+- Gender Equality in Farming
+- Climate-Smart Village Implementation
+- Sustainable Livestock Practices"""
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": f"Generate a concise cluster name based on this summary:\n\n{summary_text}"}
+ ]
+
+ try:
+ response = get_chat_response(messages)
+ # Clean up response (remove quotes, newlines, etc.)
+ cluster_name = response.strip().strip('"').strip("'").strip()
+ return cluster_name
+ except Exception as e:
+ st.error(f"Error generating cluster name: {str(e)}")
+ return "Unnamed Cluster"
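+
+# Illustrative input/output (hypothetical): a summary about distributing
+# drought-tolerant seed varieties might yield a name like
+# "Drought-Tolerant Seed Distribution", i.e. a 3-6 word, title-case label with no
+# "Cluster"/"Topic" prefix.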
+
+###############################################################################
+# NLTK Resource Initialization
+###############################################################################
+def init_nltk_resources():
+ """Initialize NLTK resources with better error handling and less verbose output"""
+ nltk.data.path.append('/home/appuser/nltk_data') # Ensure consistent data path
+
+ resources = {
+        'tokenizers/punkt_tab': 'punkt_tab',  # punkt_tab replaces the older punkt data
+ 'corpora/stopwords': 'stopwords'
+ }
+
+ for resource_path, resource_name in resources.items():
+ try:
+ nltk.data.find(resource_path)
+ except LookupError:
+ try:
+ nltk.download(resource_name, quiet=True)
+ except Exception as e:
+ st.warning(f"Error downloading NLTK resource {resource_name}: {e}")
+
+ # Test tokenizer silently
+ try:
+ from nltk.tokenize import PunktSentenceTokenizer
+ tokenizer = PunktSentenceTokenizer()
+ tokenizer.tokenize("Test sentence.")
+ except Exception as e:
+ st.error(f"Error initializing NLTK tokenizer: {e}")
+ try:
+ nltk.download('punkt_tab', quiet=True) # Updated to use punkt_tab
+ except Exception as e:
+ st.error(f"Failed to download punkt_tab tokenizer: {e}")
+
+# Initialize NLTK resources
+init_nltk_resources()
+
+###############################################################################
+# Function: add_references_to_summary
+###############################################################################
+def add_references_to_summary(summary, source_df, reference_column, url_column=None, llm=None):
+ """
+ Add references to a summary by identifying which parts of the summary come
+ from which source documents. References will be appended as [ID],
+ optionally linked if a URL column is provided.
+
+ Args:
+ summary (str): The summary text to enhance with references.
+ source_df (DataFrame): DataFrame containing the source documents.
+ reference_column (str): Column name to use for reference IDs.
+ url_column (str, optional): Column name containing URLs for hyperlinks.
+ llm (LLM, optional): Language model for source attribution.
+ Returns:
+ str: Enhanced summary with references as HTML if possible.
+ """
+ if summary.strip() == "" or source_df.empty or reference_column not in source_df.columns:
+ return summary
+
+ # If no LLM is provided, we can't do source attribution
+ if llm is None:
+ return summary
+
+ # Split the summary into paragraphs first
+ paragraphs = summary.split('\n\n')
+ enhanced_paragraphs = []
+
+ # Prepare source texts with their reference IDs
+ source_texts = []
+ reference_ids = []
+ urls = []
+ for _, row in source_df.iterrows():
+ if 'text' in row and pd.notna(row['text']) and pd.notna(row[reference_column]):
+ source_texts.append(str(row['text']))
+ reference_ids.append(str(row[reference_column]))
+ if url_column and url_column in row and pd.notna(row[url_column]):
+ urls.append(str(row[url_column]))
+ else:
+ urls.append(None)
+ if not source_texts:
+ return summary
+
+ # Create a mapping between URLs and reference IDs
+ url_map = {}
+ for ref_id, u in zip(reference_ids, urls):
+ if u:
+ url_map[ref_id] = u
+
+ # Define the system prompt for source attribution
+ system_prompt = """
+ You are an expert at identifying the source of information. You will be given:
+ 1. A sentence or bullet point from a summary
+ 2. A list of source texts with their IDs
+
+ Your task is to identify which source text(s) the text most likely came from.
+ Return ONLY the IDs of the source texts that contributed to the text, separated by commas.
+ If you cannot confidently attribute the text to any source, return "unknown".
+ """
+
+ for paragraph in paragraphs:
+ if not paragraph.strip():
+ enhanced_paragraphs.append('')
+ continue
+
+ # Check if it's a bullet point list
+ if any(line.strip().startswith('- ') or line.strip().startswith('* ') for line in paragraph.split('\n')):
+ # Handle bullet points
+ bullet_lines = paragraph.split('\n')
+ enhanced_bullets = []
+ for line in bullet_lines:
+ if not line.strip():
+ enhanced_bullets.append(line)
+ continue
+
+ if line.strip().startswith('- ') or line.strip().startswith('* '):
+ # Process each bullet point
+                    source_list_str = '\n'.join(
+                        f"ID: {ref_id}, Text: {text[:500]}..."
+                        for ref_id, text in zip(reference_ids, source_texts)
+                    )
+                    user_prompt = f"""
+                    Text: {line.strip()}
+
+                    Source texts:
+                    {source_list_str}
+
+                    Which source ID(s) did this text most likely come from? Return only the ID(s) separated by commas, or "unknown".
+                    """
+
+ try:
+ system_message = SystemMessagePromptTemplate.from_template(system_prompt)
+                        human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
+ chain = LLMChain(llm=llm, prompt=chat_prompt)
+ response = chain.run(user_prompt=user_prompt)
+ source_ids = response.strip()
+
+ if source_ids.lower() == "unknown":
+ enhanced_bullets.append(line)
+ else:
+ # Extract just the IDs
+ source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
+ source_ids = re.sub(r'\s+', '', source_ids)
+ ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
+
+ if ids:
+ ref_parts = []
+ for id_ in ids:
+ if id_ in url_map:
+                                        ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>')
+ else:
+ ref_parts.append(id_)
+ ref_string = ", ".join(ref_parts)
+ enhanced_bullets.append(f"{line} [{ref_string}]")
+ else:
+ enhanced_bullets.append(line)
+ except Exception:
+ enhanced_bullets.append(line)
+ else:
+ enhanced_bullets.append(line)
+
+ enhanced_paragraphs.append('\n'.join(enhanced_bullets))
+ else:
+ # Handle regular paragraphs
+ sentences = re.split(r'(?<=[.!?])\s+', paragraph)
+ enhanced_sentences = []
+
+ for sentence in sentences:
+ if not sentence.strip():
+ continue
+
+                source_list_str = '\n'.join(
+                    f"ID: {ref_id}, Text: {text[:500]}..."
+                    for ref_id, text in zip(reference_ids, source_texts)
+                )
+                user_prompt = f"""
+                Sentence: {sentence.strip()}
+
+                Source texts:
+                {source_list_str}
+
+                Which source ID(s) did this sentence most likely come from? Return only the ID(s) separated by commas, or "unknown".
+                """
+
+ try:
+ system_message = SystemMessagePromptTemplate.from_template(system_prompt)
+                    human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
+ chain = LLMChain(llm=llm, prompt=chat_prompt)
+ response = chain.run(user_prompt=user_prompt)
+ source_ids = response.strip()
+
+ if source_ids.lower() == "unknown":
+ enhanced_sentences.append(sentence)
+ else:
+ # Extract just the IDs
+ source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
+ source_ids = re.sub(r'\s+', '', source_ids)
+ ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
+
+ if ids:
+ ref_parts = []
+ for id_ in ids:
+ if id_ in url_map:
+                                        ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>')
+ else:
+ ref_parts.append(id_)
+ ref_string = ", ".join(ref_parts)
+ enhanced_sentences.append(f"{sentence} [{ref_string}]")
+ else:
+ enhanced_sentences.append(sentence)
+ except Exception:
+ enhanced_sentences.append(sentence)
+
+ enhanced_paragraphs.append(' '.join(enhanced_sentences))
+
+ # Join paragraphs back together with double newlines to preserve formatting
+ return '\n\n'.join(enhanced_paragraphs)
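+
+# Illustrative output (hypothetical IDs): a summary sentence such as
+#   "Farmers adopted drought-tolerant varieties."
+# is returned as
+#   "Farmers adopted drought-tolerant varieties. [1234, 5678]"
+# and each ID is wrapped in an <a href="..."> link when a URL column is supplied.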
+
+
+st.sidebar.image("static/SNAP_logo.png", width=350)
+
+###############################################################################
+# Device / GPU Info
+###############################################################################
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+if device == 'cuda':
+ st.sidebar.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
+else:
+ st.sidebar.info("Using CPU")
+
+###############################################################################
+# Load or Compute Embeddings
+###############################################################################
+@st.cache_resource
+def get_embedding_model():
+ model_dir = get_model_dir()
+ st_model_dir = os.path.join(model_dir, 'sentence_transformer')
+ os.makedirs(st_model_dir, exist_ok=True)
+
+ model_name = 'all-MiniLM-L6-v2'
+ try:
+ # Try to load from local directory first
+ model = SentenceTransformer(st_model_dir)
+ #st.success("Loaded sentence transformer from local storage")
+ except Exception as e:
+ #st.warning("Downloading sentence transformer model (one-time operation)...")
+ try:
+ # Download and save to local directory
+ model = SentenceTransformer(model_name)
+ model.save(st_model_dir)
+ #st.success("Downloaded and saved sentence transformer model")
+ except Exception as download_e:
+ st.error(f"Error downloading sentence transformer model: {str(download_e)}")
+ raise
+
+ return model.to(device)
+
+def generate_embeddings(texts, model):
+ with st.spinner('Calculating embeddings...'):
+ embeddings = model.encode(texts, show_progress_bar=True, device=device)
+ return embeddings
+
+@st.cache_data
+def load_default_dataset(default_dataset_path):
+ if os.path.exists(default_dataset_path):
+ df_ = pd.read_excel(default_dataset_path)
+ return df_
+ else:
+ st.error("Default dataset not found. Please ensure the file exists in the 'input' directory.")
+ return None
+
+@st.cache_data
+def load_uploaded_dataset(uploaded_file):
+ df_ = pd.read_excel(uploaded_file)
+ return df_
+
+def load_or_compute_embeddings(df, using_default_dataset, uploaded_file_name=None, text_columns=None):
+ """
+ Loads pre-computed embeddings from a pickle file if they match current data,
+ otherwise computes and caches them.
+ """
+ if not text_columns:
+ return None, None
+
+ base_name = "PRMS_2022_2023_2024_QAed" if using_default_dataset else "custom_dataset"
+ if uploaded_file_name:
+ base_name = os.path.splitext(uploaded_file_name)[0]
+
+    cols_key = "_".join(sorted(text_columns))
+
+    # Same naming scheme for the default and custom datasets, so uploaded files
+    # also avoid regenerating embeddings on every run.
+    embeddings_dir = BASE_DIR
+    embeddings_file = os.path.join(embeddings_dir, f"{base_name}_{cols_key}.pkl")
+
+ df_fill = df.fillna("")
+ texts = df_fill[text_columns].astype(str).agg(' '.join, axis=1).tolist()
+
+ # If already in session_state with matching columns and length, reuse
+ if ('embeddings' in st.session_state
+ and 'last_text_columns' in st.session_state
+ and st.session_state['last_text_columns'] == text_columns
+ and len(st.session_state['embeddings']) == len(texts)):
+ return st.session_state['embeddings'], st.session_state.get('embeddings_file', None)
+
+ # Try to load from disk
+ if os.path.exists(embeddings_file):
+ with open(embeddings_file, 'rb') as f:
+ embeddings = pickle.load(f)
+ if len(embeddings) == len(texts):
+ st.write("Loaded pre-calculated embeddings.")
+ st.session_state['embeddings'] = embeddings
+ st.session_state['embeddings_file'] = embeddings_file
+ st.session_state['last_text_columns'] = text_columns
+ return embeddings, embeddings_file
+
+ # Otherwise compute
+ st.write("Generating embeddings...")
+ model = get_embedding_model()
+ embeddings = generate_embeddings(texts, model)
+ with open(embeddings_file, 'wb') as f:
+ pickle.dump(embeddings, f)
+
+ st.session_state['embeddings'] = embeddings
+ st.session_state['embeddings_file'] = embeddings_file
+ st.session_state['last_text_columns'] = text_columns
+ return embeddings, embeddings_file
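+
+# Illustrative cache path: with the default dataset and text_columns set to
+# ['Title', 'Description'], embeddings are pickled next to app.py as
+#   PRMS_2022_2023_2024_QAed_Description_Title.pkl
+# (column names are sorted before being joined into the cache key).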
+
+
+###############################################################################
+# Reset Filter Function
+###############################################################################
+def reset_filters():
+ st.session_state['selected_additional_filters'] = {}
+
+# View selector
+st.sidebar.radio("Select view", ["Automatic Mode", "Power User Mode"], key="view")
+
+if st.session_state.view == "Power User Mode":
+ st.header("Power User Mode")
+ ###############################################################################
+ # Sidebar: Dataset Selection
+ ###############################################################################
+ st.sidebar.title("Data Selection")
+ dataset_option = st.sidebar.selectbox('Select Dataset', ('PRMS 2022+2023+2024 QAed', 'Upload my dataset'))
+
+ if 'df' not in st.session_state:
+ st.session_state['df'] = pd.DataFrame()
+ if 'filtered_df' not in st.session_state:
+ st.session_state['filtered_df'] = pd.DataFrame()
+
+ if dataset_option == 'PRMS 2022+2023+2024 QAed':
+ default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
+ df = load_default_dataset(default_dataset_path)
+ if df is not None:
+ st.session_state['df'] = df.copy()
+ st.session_state['using_default_dataset'] = True
+
+ # Initialize filtered_df with full dataset by default
+ if 'filtered_df' not in st.session_state or st.session_state['filtered_df'].empty:
+ st.session_state['filtered_df'] = df.copy()
+
+ # Initialize filter_state if not exists
+ if 'filter_state' not in st.session_state:
+ st.session_state['filter_state'] = {
+ 'applied': False,
+ 'filters': {}
+ }
+
+ # Set default text columns if not already set
+ if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
+ default_text_cols = []
+ if 'Title' in df.columns and 'Description' in df.columns:
+ default_text_cols = ['Title', 'Description']
+ st.session_state['text_columns'] = default_text_cols
+
+ #st.write("Using default dataset:")
+ #st.write("Data Preview:")
+ #st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
+ #st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
+
+ df_cols = df.columns.tolist()
+
+ # Additional filter columns
+ st.subheader("Select Filters")
+ if 'additional_filters_selected' not in st.session_state:
+ st.session_state['additional_filters_selected'] = []
+ if 'filter_values' not in st.session_state:
+ st.session_state['filter_values'] = {}
+
+ with st.form("filter_selection_form"):
+ all_columns = df.columns.tolist()
+ selected_additional_cols = st.multiselect(
+ "Select columns from your dataset to use as filters:",
+ all_columns,
+ default=st.session_state['additional_filters_selected']
+ )
+ add_filters_submitted = st.form_submit_button("Add Additional Filters")
+
+ if add_filters_submitted:
+ if selected_additional_cols != st.session_state['additional_filters_selected']:
+ st.session_state['additional_filters_selected'] = selected_additional_cols
+ # Reset removed columns
+ st.session_state['filter_values'] = {
+ k: v for k, v in st.session_state['filter_values'].items()
+ if k in selected_additional_cols
+ }
+
+ # Show dynamic filters form if any selected columns
+ if st.session_state['additional_filters_selected']:
+ st.subheader("Apply Filters")
+
+ # Quick search section (outside form)
+ for col_name in st.session_state['additional_filters_selected']:
+ unique_vals = sorted(df[col_name].dropna().unique().tolist())
+
+ # Add a search box for quick selection
+ search_key = f"search_{col_name}"
+ if search_key not in st.session_state:
+ st.session_state[search_key] = ""
+
+ col1, col2 = st.columns([3, 1])
+ with col1:
+ search_term = st.text_input(
+ f"Search in {col_name}",
+ key=search_key,
+ help="Enter text to find and select all matching values"
+ )
+ with col2:
+ if st.button(f"Select Matching", key=f"select_{col_name}"):
+ # Handle comma-separated values
+ if search_term:
+ matching_vals = [
+ val for val in unique_vals
+ if any(search_term.lower() in str(part).lower()
+ for part in (val.split(',') if isinstance(val, str) else [val]))
+ ]
+ # Update the multiselect default value
+ current_selected = st.session_state['filter_values'].get(col_name, [])
+ st.session_state['filter_values'][col_name] = list(set(current_selected + matching_vals))
+
+ # Show feedback about matches
+ if matching_vals:
+ st.success(f"Found and selected {len(matching_vals)} matching values")
+ else:
+ st.warning("No matching values found")
+
+ # Filter application form
+ with st.form("apply_filters_form"):
+ for col_name in st.session_state['additional_filters_selected']:
+ unique_vals = sorted(df[col_name].dropna().unique().tolist())
+ selected_vals = st.multiselect(
+ f"Filter by {col_name}",
+ options=unique_vals,
+ default=st.session_state['filter_values'].get(col_name, [])
+ )
+ st.session_state['filter_values'][col_name] = selected_vals
+
+ # Add clear filters button and apply filters button
+ col1, col2 = st.columns([1, 4])
+ with col1:
+ clear_filters = st.form_submit_button("Clear All")
+ with col2:
+ apply_filters_submitted = st.form_submit_button("Apply Filters to Dataset")
+
+ if clear_filters:
+ st.session_state['filter_values'] = {}
+ # Clear any existing summary data when filters are cleared
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+ st.rerun()
+
+ # Text columns selection moved to Advanced Settings
+ with st.expander("⚙️ Advanced Settings", expanded=False):
+ st.subheader("**Select Text Columns for Embedding**")
+ text_columns_selected = st.multiselect(
+ "Text Columns:",
+ df_cols,
+ default=st.session_state['text_columns'],
+ help="Choose columns containing text for semantic search and clustering. "
+ "If multiple are selected, their text will be concatenated."
+ )
+ st.session_state['text_columns'] = text_columns_selected
+
+ # Apply filters to the dataset
+ filtered_df = df.copy()
+ if 'apply_filters_submitted' in locals() and apply_filters_submitted:
+ # Clear any existing summary data when new filters are applied
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+
+ for col_name in st.session_state['additional_filters_selected']:
+ selected_vals = st.session_state['filter_values'].get(col_name, [])
+ if selected_vals:
+ filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
+ st.success("Filters applied successfully!")
+ st.session_state['filtered_df'] = filtered_df.copy()
+ st.session_state['filter_state'] = {
+ 'applied': True,
+ 'filters': st.session_state['filter_values'].copy()
+ }
+ # Reset any existing clustering results
+ for k in ['clustered_data', 'topic_model', 'current_clustering_data',
+ 'current_clustering_option', 'hierarchy']:
+ if k in st.session_state:
+ del st.session_state[k]
+
+ elif 'filter_state' in st.session_state and st.session_state['filter_state']['applied']:
+ # Reapply stored filters
+ for col_name, selected_vals in st.session_state['filter_state']['filters'].items():
+ if selected_vals:
+ filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
+ st.session_state['filtered_df'] = filtered_df.copy()
+
+ # Show current data preview and download button
+ if st.session_state['filtered_df'] is not None:
+ if st.session_state['filter_state']['applied']:
+ st.write("Filtered Data Preview:")
+ else:
+ st.write("Current Data Preview:")
+ st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
+ st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
+
+ output = io.BytesIO()
+ writer = pd.ExcelWriter(output, engine='openpyxl')
+ st.session_state['filtered_df'].to_excel(writer, index=False)
+ writer.close()
+ processed_data = output.getvalue()
+
+ st.download_button(
+ label="Download Current Data",
+ data=processed_data,
+ file_name='data.xlsx',
+ mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
+ )
+ else:
+ st.warning("Please ensure the default dataset exists in the 'input' directory.")
+
+ else:
+ # Upload custom dataset
+ uploaded_file = st.sidebar.file_uploader("Upload your Excel file", type=["xlsx"])
+ if uploaded_file is not None:
+ df = load_uploaded_dataset(uploaded_file)
+ if df is not None:
+ st.session_state['df'] = df.copy()
+ st.session_state['using_default_dataset'] = False
+ st.session_state['uploaded_file_name'] = uploaded_file.name
+ st.write("Data preview:")
+ st.write(df.head())
+ df_cols = df.columns.tolist()
+
+ st.subheader("**Select Text Columns for Embedding**")
+ text_columns_selected = st.multiselect(
+ "Text Columns:",
+ df_cols,
+ default=df_cols[:1] if df_cols else []
+ )
+ st.session_state['text_columns'] = text_columns_selected
+
+ st.write("**Additional Filters**")
+ selected_additional_cols = st.multiselect(
+ "Select additional columns from your dataset to use as filters:",
+ df_cols,
+ default=[]
+ )
+ st.session_state['additional_filters_selected'] = selected_additional_cols
+
+ filtered_df = df.copy()
+ for col_name in selected_additional_cols:
+ if f'selected_filter_{col_name}' not in st.session_state:
+ st.session_state[f'selected_filter_{col_name}'] = []
+ unique_vals = sorted(df[col_name].dropna().unique().tolist())
+ selected_vals = st.multiselect(
+ f"Filter by {col_name}",
+ options=unique_vals,
+ default=st.session_state[f'selected_filter_{col_name}']
+ )
+ st.session_state[f'selected_filter_{col_name}'] = selected_vals
+ if selected_vals:
+ filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
+
+ st.session_state['filtered_df'] = filtered_df
+ st.write("Filtered Data Preview:")
+ st.dataframe(filtered_df.head(), hide_index=True)
+ st.write(f"Total number of results: {len(filtered_df)}")
+
+ output = io.BytesIO()
+ writer = pd.ExcelWriter(output, engine='openpyxl')
+ filtered_df.to_excel(writer, index=False)
+ writer.close()
+ processed_data = output.getvalue()
+
+ st.download_button(
+ label="Download Filtered Data",
+ data=processed_data,
+ file_name='filtered_data.xlsx',
+ mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
+ )
+ else:
+ st.warning("Failed to load the uploaded dataset.")
+ else:
+ st.warning("Please upload an Excel file to proceed.")
+
+ if 'filtered_df' in st.session_state:
+ st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
+
+
+ ###############################################################################
+ # Preserve active tab across reruns
+ ###############################################################################
+ if 'active_tab_index' not in st.session_state:
+ st.session_state.active_tab_index = 0
+
+ tabs_titles = ["Semantic Search", "Clustering", "Summarization", "Chat", "Help"]
+ tabs = st.tabs(tabs_titles)
+ # We just create these references so we can navigate more easily
+ tab_semantic, tab_clustering, tab_summarization, tab_chat, tab_help = tabs
+
+ ###############################################################################
+ # Tab: Help
+ ###############################################################################
+ with tab_help:
+ st.header("Help")
+ st.markdown("""
+ ### About SNAP
+
+ SNAP allows you to explore, filter, search, cluster, and summarize textual datasets.
+
+ **Workflow**:
+ 1. **Data Selection (Sidebar)**: Choose the default dataset or upload your own.
+ 2. **Filtering**: Set additional filters for your dataset.
+ 3. **Select Text Columns**: Which columns to embed.
+ 4. **Semantic Search** (Tab): Provide a query and threshold to find relevant documents.
+ 5. **Clustering** (Tab): Group documents into topics.
+ 6. **Summarization** (Tab): Summarize the clustered documents (with optional references).
+
+ ### Troubleshooting
+ - If you see no results, try lowering the similarity threshold or removing negative/required keywords.
+ - Ensure you have at least one text column selected for embeddings.
+ """)
+
+ ###############################################################################
+ # Tab: Semantic Search
+ ###############################################################################
+ with tab_semantic:
+ st.header("Semantic Search")
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns:
+ st.warning("No text columns selected. Please select at least one column for text embedding.")
+ else:
+ df_full = st.session_state['df']
+ # Load or compute embeddings if necessary
+ embeddings, _ = load_or_compute_embeddings(
+ df_full,
+ st.session_state.get('using_default_dataset', False),
+ st.session_state.get('uploaded_file_name'),
+ text_columns
+ )
+
+ if embeddings is not None:
+ with st.expander("ℹ️ How Semantic Search Works", expanded=False):
+ st.markdown("""
+ ### Understanding Semantic Search
+
+ Unlike traditional keyword search that looks for exact matches, semantic search understands the meaning and context of your query. Here's how it works:
+
+ 1. **Query Processing**:
+ - Your search query is converted into a numerical representation (embedding) that captures its meaning
+ - Example: Searching for "Climate Smart Villages" will understand the concept, not just the words
+ - Related terms like "sustainable communities", "resilient farming", or "agricultural adaptation" might be found even if they don't contain the exact words
+
+ 2. **Similarity Matching**:
+ - Documents are ranked by how closely their meaning matches your query
+ - The similarity threshold controls how strict this matching is
+ - Higher threshold (e.g., 0.8) = more precise but fewer results
+ - Lower threshold (e.g., 0.3) = more results but might be less relevant
+
+ 3. **Advanced Features**:
+ - **Negative Keywords**: Use to explicitly exclude documents containing certain terms
+ - **Required Keywords**: Ensure specific terms appear in the results
+ - These work as traditional keyword filters after the semantic search
+
+ ### Search Tips
+
+ - **Phrase Queries**: Enter complete phrases for better context
+ - "Climate Smart Villages" (as one concept)
+ - Better than separate terms: "climate", "smart", "villages"
+
+ - **Descriptive Queries**: Add context for better results
+ - Instead of: "water"
+ - Better: "water management in agriculture"
+
+ - **Conceptual Queries**: Focus on concepts rather than specific terms
+ - Instead of: "increased yield"
+ - Better: "agricultural productivity improvements"
+
+ ### Example Searches
+
+ 1. **Query**: "Climate Smart Villages"
+ - Will find: Documents about climate-resilient communities, adaptive farming practices, sustainable village development
+ - Even if they don't use these exact words
+
+ 2. **Query**: "Gender equality in agriculture"
+ - Will find: Women's empowerment in farming, female farmer initiatives, gender-inclusive rural development
+ - Related concepts are captured semantically
+
+ 3. **Query**: "Sustainable water management"
+ + Required keyword: "irrigation"
+ - Combines semantic understanding of water sustainability with specific irrigation focus
+ """)
+
+ with st.form("search_parameters"):
+ query = st.text_input("Enter your search query:")
+ include_keywords = st.text_input("Include only documents containing these words (comma-separated):")
+ similarity_threshold = st.slider("Similarity threshold", 0.0, 1.0, 0.35)
+ submitted = st.form_submit_button("Search")
+
+ if submitted:
+ if query.strip():
+ with st.spinner("Performing Semantic Search..."):
+ # Clear any existing summary data when new search is run
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+
+ model = get_embedding_model()
+ df_filtered = st.session_state['filtered_df'].fillna("")
+ search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
+
+ # Filter the embeddings to the same subset
+ subset_indices = df_filtered.index
+ subset_embeddings = embeddings[subset_indices]
+
+ query_embedding = model.encode([query], device=device)
+ similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
+
+ # Show distribution
+ fig = px.histogram(
+ x=similarities,
+ nbins=30,
+ labels={'x': 'Similarity Score', 'y': 'Number of Documents'},
+ title='Distribution of Similarity Scores'
+ )
+ fig.add_vline(
+ x=similarity_threshold,
+ line_dash="dash",
+ line_color="red",
+ annotation_text=f"Threshold: {similarity_threshold:.2f}",
+ annotation_position="top"
+ )
+ st.write("### Similarity Score Distribution")
+ st.plotly_chart(fig)
+
+ above_threshold_indices = np.where(similarities > similarity_threshold)[0]
+ if len(above_threshold_indices) == 0:
+ st.warning("No results found above the similarity threshold.")
+ if 'search_results' in st.session_state:
+ del st.session_state['search_results']
+ else:
+ selected_indices = subset_indices[above_threshold_indices]
+ results = df_filtered.loc[selected_indices].copy()
+ results['similarity_score'] = similarities[above_threshold_indices]
+ results.sort_values(by='similarity_score', ascending=False, inplace=True)
+
+ # Include keyword filtering
+ if include_keywords.strip():
+ inc_words = [w.strip().lower() for w in include_keywords.split(',') if w.strip()]
+ if inc_words:
+ results = results[
+ results.apply(
+ lambda row: all(
+ w in (' '.join(row.astype(str)).lower()) for w in inc_words
+ ),
+ axis=1
+ )
+ ]
+
+ if results.empty:
+ st.warning("No results found after applying keyword filters.")
+ if 'search_results' in st.session_state:
+ del st.session_state['search_results']
+ else:
+ st.session_state['search_results'] = results.copy()
+ output = io.BytesIO()
+ writer = pd.ExcelWriter(output, engine='openpyxl')
+ results.to_excel(writer, index=False)
+ writer.close()
+ processed_data = output.getvalue()
+ st.session_state['search_results_processed_data'] = processed_data
+ else:
+ st.warning("Please enter a query to search.")
+
+ # Display search results if available
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ st.write("## Search Results")
+ results = st.session_state['search_results']
+ cols_to_display = [c for c in results.columns if c != 'similarity_score'] + ['similarity_score']
+ st.dataframe(results[cols_to_display], hide_index=True)
+ st.write(f"Total number of results: {len(results)}")
+
+ if 'search_results_processed_data' in st.session_state:
+ st.download_button(
+ label="Download Full Results",
+ data=st.session_state['search_results_processed_data'],
+ file_name='search_results.xlsx',
+ mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+ key='download_search_results'
+ )
+ else:
+ st.info("No search results to display. Enter a query and click 'Search'.")
+ else:
+ st.warning("No embeddings available because no text columns were chosen.")
+ else:
+ st.warning("Filtered dataset is empty or not loaded. Please adjust your filters or upload data.")
+
+
+ ###############################################################################
+ # Tab: Clustering
+ ###############################################################################
+ with tab_clustering:
+ st.header("Clustering")
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ # Add explanation about clustering
+ with st.expander("ℹ️ How Clustering Works", expanded=False):
+ st.markdown("""
+ ### Understanding Document Clustering
+
+ Clustering automatically groups similar documents together, helping you discover patterns and themes in your data. Here's how it works:
+
+ 1. **Cluster Formation**:
+ - Documents are grouped based on their semantic similarity
+ - Each cluster represents a distinct theme or topic
+ - Documents that are too different from others may remain unclustered (labeled as -1)
+ - The "Min Cluster Size" parameter controls how clusters are formed
+
+ 2. **Interpreting Results**:
+ - Each cluster is assigned a number (e.g., 0, 1, 2...)
+ - Cluster -1 contains "outlier" documents that didn't fit well in other clusters
+ - The size of each cluster indicates how common that theme is
+ - Keywords for each cluster show the main topics/concepts
+
+ 3. **Visualizations**:
+ - **Intertopic Distance Map**: Shows how clusters relate to each other
+ - Closer clusters are more semantically similar
+ - Size of circles indicates number of documents
+ - Hover to see top terms for each cluster
+
+ - **Topic Document Visualization**: Shows individual documents
+ - Each point is a document
+ - Colors indicate cluster membership
+ - Distance between points shows similarity
+
+ - **Topic Hierarchy**: Shows how topics are related
+ - Tree structure shows topic relationships
+ - Parent topics contain broader themes
+ - Child topics show more specific sub-themes
+
+ ### How to Use Clusters
+
+ 1. **Exploration**:
+ - Use clusters to discover main themes in your data
+ - Look for unexpected groupings that might reveal insights
+ - Identify outliers that might need special attention
+
+ 2. **Analysis**:
+ - Compare cluster sizes to understand theme distribution
+ - Examine keywords to understand what defines each cluster
+ - Use hierarchy to see how themes are nested
+
+ 3. **Practical Applications**:
+ - Generate summaries for specific clusters
+ - Focus detailed analysis on clusters of interest
+ - Use clusters to organize and categorize documents
+ - Identify gaps or overlaps in your dataset
+
+ ### Tips for Better Results
+
+ - **Adjust Min Cluster Size**:
+ - Larger values (15-20): Fewer, broader clusters
+ - Smaller values (2-5): More specific, smaller clusters
+ - Balance between too many small clusters and too few large ones
+
+ - **Choose Data Wisely**:
+ - Cluster full dataset for overall themes
+ - Cluster search results for focused analysis
+ - More documents generally give better clusters
+
+ - **Interpret with Context**:
+ - Consider your domain knowledge
+ - Look for patterns across multiple visualizations
+ - Use cluster insights to guide further analysis
+ """)
+
+ df_to_cluster = None
+
+ # Create a single form for clustering settings
+ with st.form("clustering_form"):
+ st.subheader("Clustering Settings")
+
+ # Data source selection
+ clustering_option = st.radio(
+ "Select data for clustering:",
+ ('Full Dataset', 'Filtered Dataset', 'Semantic Search Results')
+ )
+
+ # Clustering parameters
+ min_cluster_size_val = st.slider(
+ "Min Cluster Size",
+ min_value=2,
+ max_value=50,
+ value=st.session_state.get('min_cluster_size', 5),
+ help="Minimum size of each cluster in HDBSCAN; In other words, it's the minimum number of documents/texts that must be grouped together to form a valid cluster.\n\n- A larger value (e.g., 20) will result in fewer, larger clusters\n- A smaller value (e.g., 2-5) will allow for more clusters, including smaller ones\n- Documents that don't fit into any cluster meeting this minimum size requirement are labeled as noise (typically assigned to cluster -1)"
+ )
+
+ run_clustering = st.form_submit_button("Run Clustering")
+
+ if run_clustering:
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.session_state['min_cluster_size'] = min_cluster_size_val
+
+ # Decide which DataFrame is used based on the selection
+ if clustering_option == 'Semantic Search Results':
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ df_to_cluster = st.session_state['search_results'].copy()
+ else:
+ st.warning("No semantic search results found. Please run a search first.")
+ elif clustering_option == 'Filtered Dataset':
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ df_to_cluster = st.session_state['filtered_df'].copy()
+ else:
+ st.warning("Filtered dataset is empty. Please check your filters.")
+ else: # Full Dataset
+ if 'df' in st.session_state and not st.session_state['df'].empty:
+ df_to_cluster = st.session_state['df'].copy()
+
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns:
+ st.warning("No text columns selected. Please select text columns to embed before clustering.")
+ else:
+ # Ensure embeddings are available
+ df_full = st.session_state['df']
+ embeddings, _ = load_or_compute_embeddings(
+ df_full,
+ st.session_state.get('using_default_dataset', False),
+ st.session_state.get('uploaded_file_name'),
+ text_columns
+ )
+
+ if df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering:
+ with st.spinner("Performing clustering..."):
+ # Clear any existing summary data when clustering is run
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+
+ dfc = df_to_cluster.copy().fillna("")
+ dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
+
+ # Filter embeddings to those rows
+ selected_indices = dfc.index
+ embeddings_clustering = embeddings[selected_indices]
+
+ # Basic cleaning
+ stop_words = set(stopwords.words('english'))
+ texts_cleaned = []
+ for text in dfc['text'].tolist():
+ try:
+ # First try with word_tokenize
+ try:
+ word_tokens = word_tokenize(text)
+ except LookupError:
+ # If punkt is missing, try downloading it again
+ nltk.download('punkt_tab', quiet=False)
+ word_tokens = word_tokenize(text)
+ except Exception as e:
+ # If word_tokenize fails, fall back to simple splitting
+ st.warning(f"Using fallback tokenization due to error: {e}")
+ word_tokens = text.split()
+
+ filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
+ texts_cleaned.append(filtered_text)
+ except Exception as e:
+ st.error(f"Error processing text: {e}")
+ # Add the original text if processing fails
+ texts_cleaned.append(text)
+
+ try:
+ # Validation checks before clustering
+ if len(texts_cleaned) < min_cluster_size_val:
+ st.error(f"Not enough documents to form clusters. You have {len(texts_cleaned)} documents but minimum cluster size is set to {min_cluster_size_val}.")
+ st.session_state['clustering_error'] = "Insufficient documents for clustering"
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Convert embeddings to CPU numpy if needed
+ if torch.is_tensor(embeddings_clustering):
+ embeddings_for_clustering = embeddings_clustering.cpu().numpy()
+ else:
+ embeddings_for_clustering = embeddings_clustering
+
+ # Additional validation
+ if embeddings_for_clustering.shape[0] != len(texts_cleaned):
+ st.error("Mismatch between number of embeddings and texts.")
+ st.session_state['clustering_error'] = "Embedding and text count mismatch"
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Build the HDBSCAN model with error handling
+ try:
+ hdbscan_model = HDBSCAN(
+ min_cluster_size=min_cluster_size_val,
+ metric='euclidean',
+ cluster_selection_method='eom'
+ )
+
+ # Build the BERTopic model
+ topic_model = BERTopic(
+ embedding_model=get_embedding_model(),
+ hdbscan_model=hdbscan_model
+ )
+
+ # Fit the model and get topics
+ topics, probs = topic_model.fit_transform(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+
+ # Validate clustering results
+ unique_topics = set(topics)
+ if len(unique_topics) < 2:
+ st.warning("Clustering resulted in too few clusters. Retry or try reducing the minimum cluster size.")
+ if -1 in unique_topics:
+ non_noise_docs = sum(1 for t in topics if t != -1)
+ st.info(f"Only {non_noise_docs} documents were assigned to clusters. The rest were marked as noise (-1).")
+ if non_noise_docs < min_cluster_size_val:
+ st.error("Not enough documents were successfully clustered. Try reducing the minimum cluster size.")
+ st.session_state['clustering_error'] = "Insufficient clustered documents"
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Store results if validation passes
+ dfc['Topic'] = topics
+ st.session_state['topic_model'] = topic_model
+ st.session_state['clustered_data'] = dfc.copy()
+ st.session_state['clustering_texts_cleaned'] = texts_cleaned
+ st.session_state['clustering_embeddings'] = embeddings_for_clustering
+ st.session_state['clustering_completed'] = True
+
+ # Try to generate visualizations with error handling
+ try:
+ st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
+ except Exception as viz_error:
+ st.warning("Could not generate topic visualization. This usually happens when there are too few total clusters. Try adjusting the minimum cluster size or adding more documents.")
+ st.session_state['intertopic_distance_fig'] = None
+
+ try:
+ st.session_state['topic_document_fig'] = topic_model.visualize_documents(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+ except Exception as viz_error:
+ st.warning("Could not generate document visualization. This might happen when the clustering results are not optimal. Try adjusting the clustering parameters.")
+ st.session_state['topic_document_fig'] = None
+
+ try:
+ hierarchy = topic_model.hierarchical_topics(texts_cleaned)
+ st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
+ st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
+ except Exception as viz_error:
+ st.warning("Could not generate topic hierarchy visualization. This usually happens when there aren't enough distinct topics to form a hierarchy.")
+ st.session_state['hierarchy'] = pd.DataFrame()
+ st.session_state['hierarchy_fig'] = None
+
+ except ValueError as ve:
+ if "zero-size array to reduction operation maximum which has no identity" in str(ve):
+ st.error("Clustering failed: No valid clusters could be formed. Try reducing the minimum cluster size.")
+ elif "Cannot use scipy.linalg.eigh for sparse A with k > N" in str(ve):
+ st.error("Clustering failed: Too many components requested for the number of documents. Try with more documents or adjust clustering parameters.")
+ else:
+ st.error(f"Clustering error: {str(ve)}")
+ st.session_state['clustering_error'] = str(ve)
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ except Exception as e:
+ st.error(f"An error occurred during clustering: {str(e)}")
+ st.session_state['clustering_error'] = str(e)
+ st.session_state['clustering_completed'] = False
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Display clustering results if they exist
+ if st.session_state.get('clustering_completed', False):
+ st.subheader("Topic Overview")
+ dfc = st.session_state['clustered_data']
+ topic_model = st.session_state['topic_model']
+ topics = dfc['Topic'].tolist()
+
+ unique_topics = sorted(list(set(topics)))
+ cluster_info = []
+ for t in unique_topics:
+ cluster_docs = dfc[dfc['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ if top_words:
+ top_keywords = ", ".join([w[0] for w in top_words[:5]])
+ else:
+ top_keywords = "N/A"
+ cluster_info.append((t, count, top_keywords))
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
+
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ st.subheader("Clustering Results")
+ columns_to_display = [c for c in dfc.columns if c != 'text']
+ st.dataframe(dfc[columns_to_display], hide_index=True)
+
+ # Display stored visualizations with error handling
+ st.write("### Intertopic Distance Map")
+ if st.session_state.get('intertopic_distance_fig') is not None:
+ try:
+ st.plotly_chart(st.session_state['intertopic_distance_fig'])
+ except Exception:
+ st.info("Topic visualization is not available for the current clustering results.")
+
+ st.write("### Topic Document Visualization")
+ if st.session_state.get('topic_document_fig') is not None:
+ try:
+ st.plotly_chart(st.session_state['topic_document_fig'])
+ except Exception:
+ st.info("Document visualization is not available for the current clustering results.")
+
+ st.write("### Topic Hierarchy")
+ if st.session_state.get('hierarchy_fig') is not None:
+ try:
+ st.plotly_chart(st.session_state['hierarchy_fig'])
+ except Exception:
+ st.info("Topic hierarchy visualization is not available for the current clustering results.")
+ else:
+ st.warning("Please select or upload a dataset and filter as needed.")
+
+
+ ###############################################################################
+ # Tab: Summarization
+ ###############################################################################
+ with tab_summarization:
+ st.header("Summarization")
+ # Add explanation about summarization
+ with st.expander("ℹ️ How Summarization Works", expanded=False):
+ st.markdown("""
+ ### Understanding Document Summarization
+
+ Summarization condenses multiple documents into concise, meaningful summaries while preserving key information. Here's how it works:
+
+ 1. **Summary Generation**:
+ - Documents are processed using advanced language models
+ - Key themes and important points are identified
+ - Content is condensed while maintaining context
+ - Both high-level and cluster-specific summaries are available
+
+ 2. **Reference System**:
+ - Summaries can include references to source documents
+ - References are shown as [ID] or as clickable links
+ - Each statement can be traced back to its source
+ - Helps maintain accountability and verification
+
+ 3. **Types of Summaries**:
+ - **High-Level Summary**: Overview of all selected documents
+ - Captures main themes across the entire selection
+ - Ideal for quick understanding of large document sets
+ - Shows relationships between different topics
+
+ - **Cluster-Specific Summaries**: Focused on each cluster
+ - More detailed for specific themes
+ - Shows unique aspects of each cluster
+ - Helps understand sub-topics in depth
+
+ ### How to Use Summaries
+
+ 1. **Configuration**:
+ - Choose between all clusters or specific ones
+ - Set temperature for creativity vs. consistency
+ - Adjust max tokens for summary length
+ - Enable/disable reference system
+
+ 2. **Reference Options**:
+ - Select column for reference IDs
+ - Add hyperlinks to references
+ - Choose URL column for clickable links
+ - References help track information sources
+
+ 3. **Practical Applications**:
+ - Quick overview of large datasets
+ - Detailed analysis of specific themes
+ - Evidence-based reporting with references
+ - Compare different document groups
+
+ ### Tips for Better Results
+
+ - **Temperature Setting**:
+ - Higher (0.7-1.0): More creative, varied summaries
+ - Lower (0.1-0.3): More consistent, conservative summaries
+ - Balance based on your needs for creativity vs. consistency
+
+ - **Token Length**:
+ - Longer limits: More detailed summaries
+ - Shorter limits: More concise, focused summaries
+ - Adjust based on document complexity
+
+ - **Reference Usage**:
+ - Enable references for traceability
+ - Use hyperlinks for easy navigation
+ - Choose meaningful reference columns
+ - Helps validate summary accuracy
+
+ ### Best Practices
+
+ 1. **For General Overview**:
+ - Use high-level summary
+ - Keep temperature moderate (0.5-0.7)
+ - Enable references for verification
+ - Focus on broader themes
+
+ 2. **For Detailed Analysis**:
+ - Use cluster-specific summaries
+ - Adjust temperature based on need
+ - Include references with hyperlinks
+ - Look for patterns within clusters
+
+ 3. **For Reporting**:
+ - Combine both summary types
+ - Use references extensively
+ - Balance detail and brevity
+ - Ensure source traceability
+ """)
+
+ df_summ = None
+ # We'll try to summarize either the clustered data or just the filtered dataset
+ if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
+ df_summ = st.session_state['clustered_data']
+ elif 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ df_summ = st.session_state['filtered_df']
+ else:
+ st.warning("No data available for summarization. Please cluster first or have some filtered data.")
+
+ if df_summ is not None and not df_summ.empty:
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns:
+ st.warning("No text columns selected. Please select columns for text embedding first.")
+ else:
+ if 'Topic' not in df_summ.columns or 'topic_model' not in st.session_state:
+ st.warning("No 'Topic' column found. Summaries per cluster are only available if you've run clustering.")
+ else:
+ topic_model = st.session_state['topic_model']
+ df_summ['text'] = df_summ.fillna("").astype(str)[text_columns].agg(' '.join, axis=1)
+
+ # List of topics
+ topics = sorted(df_summ['Topic'].unique())
+ cluster_info = []
+ for t in topics:
+ cluster_docs = df_summ[df_summ['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ if top_words:
+ top_keywords = ", ".join([w[0] for w in top_words[:5]])
+ else:
+ top_keywords = "N/A"
+ cluster_info.append((t, count, top_keywords))
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
+
+ # If we have cluster names from previous summarization, add them
+ if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
+ summary_df = st.session_state['summary_df']
+ # Create a mapping of topic to name for merging
+ topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
+ # Add cluster names to cluster_df
+ cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
+ # Reorder columns to show name after topic
+ cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
+
+ st.write("### Available Clusters:")
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ # Summarization settings
+ st.subheader("Summarization Settings")
+ # Summaries scope
+ summary_scope = st.radio(
+ "Generate summaries for:",
+ ["All clusters", "Specific clusters"]
+ )
+ if summary_scope == "Specific clusters":
+ # Format options to include cluster names if available
+ if 'Cluster_Name' in cluster_df.columns:
+ topic_options = [f"Cluster {t} - {name}" for t, name in zip(cluster_df['Topic'], cluster_df['Cluster_Name'])]
+ topic_to_id = {opt: t for opt, t in zip(topic_options, cluster_df['Topic'])}
+ selected_topic_options = st.multiselect("Select clusters to summarize", topic_options)
+ selected_topics = [topic_to_id[opt] for opt in selected_topic_options]
+ else:
+ selected_topics = st.multiselect("Select clusters to summarize", topics)
+ else:
+ selected_topics = topics
+
+ # Add system prompt configuration
+ default_system_prompt = """You are an expert summarizer skilled in creating concise and relevant summaries.
+ You will be given text and an objective context. Please produce a clear, cohesive,
+ and thematically relevant summary.
+ Focus on key points, insights, or patterns that emerge from the text."""
+
+ if 'system_prompt' not in st.session_state:
+ st.session_state['system_prompt'] = default_system_prompt
+
+ with st.expander("🔧 Advanced Settings", expanded=False):
+ st.markdown("""
+ ### System Prompt Configuration
+
+ The system prompt guides the AI in how to generate summaries. You can customize it to better suit your needs:
+ - Be specific about the style and focus you want
+ - Add domain-specific context if needed
+ - Include any special formatting requirements
+ """)
+
+ system_prompt = st.text_area(
+ "Customize System Prompt",
+ value=st.session_state['system_prompt'],
+ height=150,
+ help="This prompt guides the AI in how to generate summaries. Edit it to customize the summary style and focus."
+ )
+
+ if st.button("Reset to Default"):
+ system_prompt = default_system_prompt
+ st.session_state['system_prompt'] = default_system_prompt
+
+ st.markdown("### Generation Parameters")
+ temperature = st.slider(
+ "Temperature",
+ 0.0, 1.0, 0.7,
+ help="Higher values (0.7-1.0) make summaries more creative but less predictable. Lower values (0.1-0.3) make them more focused and consistent."
+ )
+ max_tokens = st.slider(
+ "Max Tokens",
+ 100, 3000, 1000,
+ help="Maximum length of generated summaries. Higher values allow for more detailed summaries but take longer to generate."
+ )
+
+ st.session_state['system_prompt'] = system_prompt
+
+ st.write("### Enhanced Summary References")
+ st.write("Select columns for references (optional).")
+ all_cols = [c for c in df_summ.columns if c not in ['text', 'Topic', 'similarity_score']]
+
+ # By default, let's guess the first column as reference ID if available
+ if 'reference_id_column' not in st.session_state:
+ st.session_state.reference_id_column = all_cols[0] if all_cols else None
+ # If there's a column that looks like a URL, guess that
+ url_guess = next((c for c in all_cols if 'url' in c.lower() or 'link' in c.lower()), None)
+ if 'url_column' not in st.session_state:
+ st.session_state.url_column = url_guess
+
+ enable_references = st.checkbox(
+ "Enable references in summaries",
+ value=True, # default to True as requested
+ help="Add source references to the final summary text."
+ )
+ reference_id_column = st.selectbox(
+ "Select column to use as reference ID:",
+ all_cols,
+ index=all_cols.index(st.session_state.reference_id_column) if st.session_state.reference_id_column in all_cols else 0
+ )
+ add_hyperlinks = st.checkbox(
+ "Add hyperlinks to references",
+ value=True, # default to True
+ help="If the reference column has a matching URL, make it clickable."
+ )
+ url_column = None
+ if add_hyperlinks:
+ url_column = st.selectbox(
+ "Select column containing URLs:",
+ all_cols,
+ index=all_cols.index(st.session_state.url_column) if (st.session_state.url_column in all_cols) else 0
+ )
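+ # A minimal sketch of the reference link the downstream summarizer is assumed
+ # to build when a URL column is provided (names here are illustrative only):
+ #   ref_html = f'<a href="{row[url_column]}" target="_blank">{row[reference_id_column]}</a>'
+ # (Without a URL, the reference presumably falls back to the plain [ID] form.)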
+
+ # Summarization button
+ if st.button("Generate Summaries"):
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
+ if not openai_api_key:
+ st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
+ else:
+ # Set flag to indicate summarization button was clicked
+ st.session_state['_summarization_button_clicked'] = True
+
+ llm = ChatOpenAI(
+ api_key=openai_api_key,
+ model_name='gpt-4o-mini', # or 'gpt-4o'
+ temperature=temperature,
+ max_tokens=max_tokens
+ )
+
+ # Filter to selected topics
+ if selected_topics:
+ df_scope = df_summ[df_summ['Topic'].isin(selected_topics)]
+ else:
+ st.warning("No topics selected for summarization.")
+ df_scope = pd.DataFrame()
+
+ if df_scope.empty:
+ st.warning("No documents match the selected topics for summarization.")
+ else:
+ all_texts = df_scope['text'].tolist()
+ combined_text = " ".join(all_texts)
+ if not combined_text.strip():
+ st.warning("No text data available for summarization.")
+ else:
+ # For cluster-specific summaries, use the customized prompt
+ local_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
+ local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
+
+ # Summaries per cluster
+ # Only if multiple clusters are selected
+ unique_selected_topics = df_scope['Topic'].unique()
+ if len(unique_selected_topics) > 1:
+ st.write("### Summaries per Selected Cluster")
+
+ # Process summaries in parallel
+ with st.spinner("Generating cluster summaries in parallel..."):
+ summaries = process_summaries_in_parallel(
+ df_scope=df_scope,
+ unique_selected_topics=unique_selected_topics,
+ llm=llm,
+ chat_prompt=local_chat_prompt,
+ enable_references=enable_references,
+ reference_id_column=reference_id_column,
+ url_column=url_column if add_hyperlinks else None,
+ max_workers=min(16, len(unique_selected_topics)) # Limit workers based on clusters
+ )
+
+ if summaries:
+ summary_df = pd.DataFrame(summaries)
+ # Store the summaries DataFrame in session state
+ st.session_state['summary_df'] = summary_df
+ # Store additional summary info in session state
+ st.session_state['has_references'] = enable_references
+ st.session_state['reference_id_column'] = reference_id_column
+ st.session_state['url_column'] = url_column if add_hyperlinks else None
+
+ # Update cluster_df with new names
+ if 'Cluster_Name' in summary_df.columns:
+ topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
+ cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
+ cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
+
+ # Immediately display updated cluster overview
+ st.write("### Updated Topic Overview:")
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ # Now generate high-level summary from the cluster summaries
+ with st.spinner("Generating high-level summary from cluster summaries..."):
+ # Format cluster summaries with proper markdown and HTML
+ formatted_summaries = []
+ total_tokens = 0
+ MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) # Leave room for system prompt and completion
+ summary_batches = []
+ current_batch = []
+ current_batch_tokens = 0
+
+ for _, row in summary_df.iterrows():
+ summary_text = row.get('Enhanced_Summary', row['Summary'])
+ formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
+ summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
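+ # tokenizer(...) returns a dict of lists; len(...["input_ids"]) is the GPT-2
+ # token count, used here as a proxy for the GPT-4o tokenizer's count.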
+
+ # If adding this summary would exceed the safe token limit, start a new batch
+ if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
+ if current_batch: # Only append if we have summaries in the current batch
+ summary_batches.append(current_batch)
+ current_batch = []
+ current_batch_tokens = 0
+
+ current_batch.append(formatted_summary)
+ current_batch_tokens += summary_tokens
+
+ # Add the last batch if it has any summaries
+ if current_batch:
+ summary_batches.append(current_batch)
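+ # Greedy batching: with MAX_CONTEXT_WINDOW = 128,000, MAX_SAFE_TOKENS is 96,000,
+ # so e.g. summaries of roughly 3,000 tokens each pack about 32 per batch before
+ # a new batch is started.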
+
+ # Generate overview for each batch
+ batch_overviews = []
+ with st.spinner("Generating batch summaries..."):
+ for i, batch in enumerate(summary_batches, 1):
+ st.write(f"Processing batch {i} of {len(summary_batches)}...")
+
+ batch_text = "\n\n".join(batch)
+ batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or ID.
+
+Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
+1. Preserve all hyperlinked references exactly as they appear in the input summaries
+2. Maintain the HTML anchor tags (<a>...</a>) intact when using information from the summaries
+3. Keep the markdown formatting for better readability
+4. Note that this is part {i} of {len(summary_batches)} parts, so focus on the themes present in these specific clusters
+
+Here are the cluster summaries to synthesize:
+
+{batch_text}"""
+
+ # Generate overview for this batch
+ high_level_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
+ high_level_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ high_level_chat_prompt = ChatPromptTemplate.from_messages([high_level_system_message, high_level_human_message])
+ high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
+ batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
+ batch_overviews.append(batch_overview)
+
+ # Now combine the batch overviews
+ with st.spinner("Generating final combined summary..."):
+ combined_overviews = "\n\n".join([f"### Part {i+1}:\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ final_prompt = f"""Below are {len(batch_overviews)} overview summaries, each covering different clusters of research results. Each part maintains its original references to source documents.
+
+Please create a final comprehensive synthesis that:
+1. Integrates the key themes and findings from all parts
+2. Preserves all hyperlinked references exactly as they appear
+3. Maintains the HTML anchor tags (<a>...</a>) intact
+4. Keeps the markdown formatting for better readability
+5. Creates a coherent narrative across all parts
+6. Highlights any themes that span multiple parts
+
+Here are the overviews to synthesize:
+
+{combined_overviews}"""
+
+ # Verify the final prompt's token count
+ final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
+ if final_prompt_tokens > MAX_SAFE_TOKENS:
+ st.error(f"❌ Final synthesis prompt ({final_prompt_tokens:,} tokens) exceeds safe limit ({MAX_SAFE_TOKENS:,}). Using batch summaries separately.")
+ high_level_summary = "# Overall Summary\n\n" + "\n\n".join([f"## Batch {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ else:
+ # Generate final synthesis
+ high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
+ high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
+
+ # Store both versions of the summary
+ st.session_state['high_level_summary'] = high_level_summary
+ st.session_state['enhanced_summary'] = high_level_summary
+
+ # Set flag to indicate summarization is complete
+ st.session_state['summarization_completed'] = True
+
+ # Update the display without rerunning
+ st.write("### High-Level Summary:")
+ st.markdown(high_level_summary, unsafe_allow_html=True)
+
+ # Display cluster summaries
+ st.write("### Cluster Summaries:")
+ if enable_references and 'Enhanced_Summary' in summary_df.columns:
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ st.write(f"**Topic {row['Topic']} - {cluster_name}**")
+ st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
+ st.write("---")
+ with st.expander("View original summaries in table format"):
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.write("### Summaries per Cluster:")
+ if 'Cluster_Name' in summary_df.columns:
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.dataframe(summary_df, hide_index=True)
+
+ # Download
+ if 'Enhanced_Summary' in summary_df.columns:
+ dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ else:
+ dl_df = summary_df
+ csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
+ b64 = base64.b64encode(csv_bytes).decode()
+ # Build a data-URI download link for the summaries CSV (filename is arbitrary)
+ href = f'<a href="data:text/csv;base64,{b64}" download="cluster_summaries.csv">Download Summaries CSV</a>'
+ st.markdown(href, unsafe_allow_html=True)
+
+ # Display existing summaries if available and summarization was completed
+ if st.session_state.get('summarization_completed', False):
+ if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
+ if 'high_level_summary' in st.session_state:
+ st.write("### High-Level Summary:")
+ st.markdown(st.session_state['enhanced_summary'] if st.session_state.get('enhanced_summary') else st.session_state['high_level_summary'], unsafe_allow_html=True)
+
+ st.write("### Cluster Summaries:")
+ summary_df = st.session_state['summary_df']
+ if 'Enhanced_Summary' in summary_df.columns:
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ st.write(f"**Topic {row['Topic']} - {cluster_name}**")
+ st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
+ st.write("---")
+ with st.expander("View original summaries in table format"):
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.dataframe(summary_df, hide_index=True)
+
+ # Add download button for existing summaries
+ dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
+ if 'Cluster_Name' in dl_df.columns:
+ dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
+ b64 = base64.b64encode(csv_bytes).decode()
+ href = f'<a href="data:text/csv;base64,{b64}" download="cluster_summaries.csv">Download Summaries CSV</a>'
+ st.markdown(href, unsafe_allow_html=True)
+ else:
+ st.warning("No data available for summarization.")
+
+ # Display existing summaries if available (when returning to the tab)
+ if not st.session_state.get('_summarization_button_clicked', False): # Only show if not just generated
+ if 'high_level_summary' in st.session_state:
+ st.write("### Existing High-Level Summary:")
+ if st.session_state.get('enhanced_summary'):
+ st.markdown(st.session_state['enhanced_summary'], unsafe_allow_html=True)
+ with st.expander("View original summary (without references)"):
+ st.write(st.session_state['high_level_summary'])
+ else:
+ st.write(st.session_state['high_level_summary'])
+
+ if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
+ st.write("### Existing Cluster Summaries:")
+ summary_df = st.session_state['summary_df']
+ if 'Enhanced_Summary' in summary_df.columns:
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ st.write(f"**Topic {row['Topic']} - {cluster_name}**")
+ st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
+ st.write("---")
+ with st.expander("View original summaries in table format"):
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.dataframe(summary_df, hide_index=True)
+
+ # Add download button for existing summaries
+ dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
+ if 'Cluster_Name' in dl_df.columns:
+ dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
+ b64 = base64.b64encode(csv_bytes).decode()
+ href = f'<a href="data:text/csv;base64,{b64}" download="cluster_summaries.csv">Download Summaries CSV</a>'
+ st.markdown(href, unsafe_allow_html=True)
+
+
+ ###############################################################################
+ # Tab: Chat
+ ###############################################################################
+ with tab_chat:
+ st.header("Chat with Your Data")
+
+ # Add explanation about chat functionality
+ with st.expander("ℹ️ How Chat Works", expanded=False):
+ st.markdown("""
+ ### Understanding Chat with Your Data
+
+ The chat functionality allows you to have an interactive conversation about your data, whether it's filtered, clustered, or raw. Here's how it works:
+
+ 1. **Data Selection**:
+ - Choose which dataset to chat about (filtered, clustered, or search results)
+ - Optionally focus on specific clusters if clustering was performed
+ - System automatically includes relevant context from your selection
+
+ 2. **Context Window**:
+ - Shows how much of the GPT-4o context window is being used
+ - Helps you understand if you need to filter data further
+ - Displays token usage statistics
+
+ 3. **Chat Features**:
+ - Ask questions about your data
+ - Get insights and analysis
+ - Reference specific documents or clusters
+ - Download chat context for transparency
+
+ ### Best Practices
+
+ 1. **Data Selection**:
+ - Start with filtered or clustered data for more focused conversations
+ - Select specific clusters if you want to dive deep into a topic
+ - Consider the context window usage when selecting data
+
+ 2. **Asking Questions**:
+ - Be specific in your questions
+ - Ask about patterns, trends, or insights
+ - Reference clusters or documents by their IDs
+ - Build on previous questions for deeper analysis
+
+ 3. **Managing Context**:
+ - Monitor the context window usage
+ - Filter data further if context is too full
+ - Download chat context for documentation
+ - Clear chat history to start fresh
+
+ ### Tips for Better Results
+
+ - **Question Types**:
+ - "What are the main themes in cluster 3?"
+ - "Compare the findings between clusters 1 and 2"
+ - "Summarize the methodology used across these documents"
+ - "What are the common outcomes reported?"
+
+ - **Follow-up Questions**:
+ - Build on previous answers
+ - Ask for clarification
+ - Request specific examples
+ - Explore relationships between findings
+ """)
+
+ # Function to check data source availability
+ def get_available_data_sources():
+ sources = []
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ sources.append("Filtered Dataset")
+ if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
+ sources.append("Clustered Data")
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ sources.append("Search Results")
+ if ('high_level_summary' in st.session_state or
+ ('summary_df' in st.session_state and not st.session_state['summary_df'].empty)):
+ sources.append("Summarized Data")
+ return sources
+
+ # Get available data sources
+ available_sources = get_available_data_sources()
+
+ if not available_sources:
+ st.warning("No data available for chat. Please filter, cluster, search, or summarize first.")
+ st.stop()
+
+ # Initialize or update data source in session state
+ if 'chat_data_source' not in st.session_state:
+ st.session_state.chat_data_source = available_sources[0]
+ elif st.session_state.chat_data_source not in available_sources:
+ st.session_state.chat_data_source = available_sources[0]
+
+ # Data source selection with automatic fallback
+ data_source = st.radio(
+ "Select data to chat about:",
+ available_sources,
+ index=available_sources.index(st.session_state.chat_data_source),
+ help="Choose which dataset you want to analyze in the chat."
+ )
+
+ # Update session state if data source changed
+ if data_source != st.session_state.chat_data_source:
+ st.session_state.chat_data_source = data_source
+ # Clear any cluster-specific selections if switching data sources
+ if 'chat_selected_cluster' in st.session_state:
+ del st.session_state.chat_selected_cluster
+
+ # Get the appropriate DataFrame based on selected source
+ df_chat = None
+ if data_source == "Filtered Dataset":
+ df_chat = st.session_state['filtered_df']
+ elif data_source == "Clustered Data":
+ df_chat = st.session_state['clustered_data']
+ elif data_source == "Search Results":
+ df_chat = st.session_state['search_results']
+ elif data_source == "Summarized Data":
+ # Create DataFrame with selected summaries
+ summary_rows = []
+
+ # Add high-level summary if available
+ if 'high_level_summary' in st.session_state:
+ summary_rows.append({
+ 'Summary_Type': 'High-Level Summary',
+ 'Content': st.session_state.get('enhanced_summary', st.session_state['high_level_summary'])
+ })
+
+ # Add cluster summaries if available
+ if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
+ summary_df = st.session_state['summary_df']
+ for _, row in summary_df.iterrows():
+ summary_rows.append({
+ 'Summary_Type': f"Cluster {row['Topic']} Summary",
+ 'Content': row.get('Enhanced_Summary', row['Summary'])
+ })
+
+ if summary_rows:
+ df_chat = pd.DataFrame(summary_rows)
+
+ if df_chat is not None and not df_chat.empty:
+ # If we have clustered data, allow cluster selection
+ selected_cluster = None
+ if data_source != "Summarized Data" and 'Topic' in df_chat.columns:
+ cluster_option = st.radio(
+ "Choose cluster scope:",
+ ["All Clusters", "Specific Cluster"]
+ )
+ if cluster_option == "Specific Cluster":
+ unique_topics = sorted(df_chat['Topic'].unique())
+ # Check if we have cluster names
+ if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
+ summary_df = st.session_state['summary_df']
+ # Create a mapping of topic to name
+ topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
+ # Format the selectbox options
+ topic_options = [
+ (t, f"Cluster {t} - {topic_names.get(t, 'Unnamed Cluster')}")
+ for t in unique_topics
+ ]
+ selected_cluster = st.selectbox(
+ "Select cluster to focus on:",
+ [t[0] for t in topic_options],
+ format_func=lambda x: next(opt[1] for opt in topic_options if opt[0] == x)
+ )
+ else:
+ selected_cluster = st.selectbox(
+ "Select cluster to focus on:",
+ unique_topics,
+ format_func=lambda x: f"Cluster {x}"
+ )
+ if selected_cluster is not None:
+ df_chat = df_chat[df_chat['Topic'] == selected_cluster]
+ st.session_state.chat_selected_cluster = selected_cluster
+ elif 'chat_selected_cluster' in st.session_state:
+ del st.session_state.chat_selected_cluster
+
+ # Prepare the data for chat context
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns and data_source != "Summarized Data":
+ st.warning("No text columns selected. Please select text columns to enable chat functionality.")
+ st.stop()
+
+ # Instead of limiting to 210 documents, we'll limit by tokens
+ MAX_ALLOWED_TOKENS = int(MAX_CONTEXT_WINDOW * 0.95) # 95% of context window
+
+ # Prepare system message first to account for its tokens
+ system_msg = {
+ "role": "system",
+ "content": """You are a specialized assistant analyzing data from a research database.
+ Your role is to:
+ 1. Provide clear, concise answers based on the data provided
+ 2. Highlight relevant information from specific results when answering
+ 3. When referencing specific results, use their row index or ID if available
+ 4. Clearly state if information is not available in the results
+ 5. Maintain a professional and analytical tone
+ 6. Format your responses using Markdown:
+ - Use **bold** for emphasis
+ - Use bullet points and numbered lists for structured information
+ - Create tables using Markdown syntax when presenting structured data
+ - Use backticks for code or technical terms
+ - Include hyperlinks when referencing external sources
+ - Use headings (###) to organize long responses
+
+ The data is provided in a structured format where:""" + ("""
+ - Each result contains multiple fields
+ - Text content is primarily in the following columns: """ + ", ".join(text_columns) + """
+ - Additional metadata and fields are available for reference
+ - If clusters are present, they are numbered (e.g., Cluster 0, Cluster 1, etc.)""" if data_source != "Summarized Data" else """
+ - The data consists of AI-generated summaries of the documents
+ - Each summary may contain references to source documents in markdown format
+ - References are shown as [ID] or as clickable hyperlinks
+ - Summaries may be high-level (covering all documents) or cluster-specific""") + """
+ """
+ }
+
+ # Calculate system message tokens
+ system_tokens = len(tokenizer(system_msg["content"])["input_ids"])
+ remaining_tokens = MAX_ALLOWED_TOKENS - system_tokens
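+ # Example budget: MAX_ALLOWED_TOKENS is int(MAX_CONTEXT_WINDOW * 0.95) = 121,600
+ # tokens, so a ~600-token system message leaves roughly 121,000 for the data context.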
+
+ # Prepare the data context with token limiting
+ data_text = "Available Data:\n"
+ included_rows = 0
+ total_rows = len(df_chat)
+
+ if data_source == "Summarized Data":
+ # For summarized data, process row by row
+ for idx, row in df_chat.iterrows():
+ row_text = f"\n{row['Summary_Type']}:\n{row['Content']}\n"
+ row_tokens = len(tokenizer(row_text)["input_ids"])
+
+ if remaining_tokens - row_tokens > 0:
+ data_text += row_text
+ remaining_tokens -= row_tokens
+ included_rows += 1
+ else:
+ break
+ else:
+ # For regular data, process row by row
+ for idx, row in df_chat.iterrows():
+ row_text = f"\nItem {idx}:\n"
+ for col in df_chat.columns:
+ if not pd.isna(row[col]) and str(row[col]).strip() and col != 'similarity_score':
+ row_text += f"{col}: {row[col]}\n"
+
+ row_tokens = len(tokenizer(row_text)["input_ids"])
+ if remaining_tokens - row_tokens > 0:
+ data_text += row_text
+ remaining_tokens -= row_tokens
+ included_rows += 1
+ else:
+ break
+
+ # Calculate token usage
+ data_tokens = len(tokenizer(data_text)["input_ids"])
+ total_tokens = system_tokens + data_tokens
+ context_usage_percent = (total_tokens / MAX_CONTEXT_WINDOW) * 100
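+ # e.g. a 600-token system message plus 90,000 data tokens gives
+ # 90,600 / 128,000 ≈ 70.8% context usage.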
+
+ # Display token usage and data coverage
+ st.subheader("Context Window Usage")
+ st.write(f"System Message: {system_tokens:,} tokens")
+ st.write(f"Data Context: {data_tokens:,} tokens")
+ st.write(f"Total: {total_tokens:,} tokens ({context_usage_percent:.1f}% of available context)")
+ st.write(f"Documents included: {included_rows:,} out of {total_rows:,} ({(included_rows/total_rows*100):.1f}%)")
+
+ if context_usage_percent > 90:
+ st.warning("⚠️ High context usage! Consider reducing the number of results or filtering further.")
+ elif context_usage_percent > 75:
+ st.info("ℹ️ Moderate context usage. Still room for your question, but consider reducing results if asking a long question.")
+
+ # Add download button for chat context
+ chat_context = f"""System Message:
+ {system_msg['content']}
+
+ {data_text}"""
+ st.download_button(
+ label="📥 Download Chat Context",
+ data=chat_context,
+ file_name="chat_context.txt",
+ mime="text/plain",
+ help="Download the exact context that the chatbot receives"
+ )
+
+ # Chat interface
+ col_chat1, col_chat2 = st.columns([3, 1])
+ with col_chat1:
+ user_input = st.text_area("Ask a question about your data:", key="chat_input")
+ with col_chat2:
+ if st.button("Clear Chat History"):
+ st.session_state.chat_history = []
+ st.rerun()
+
+ # Store current tab index before processing
+ current_tab = tabs_titles.index("Chat")
+
+ if st.button("Send", key="send_button"):
+ if user_input:
+ # Set the active tab index to stay on Chat
+ st.session_state.active_tab_index = current_tab
+
+ with st.spinner("Processing your question..."):
+ # Add user's question to chat history
+ st.session_state.chat_history.append({"role": "user", "content": user_input})
+
+ # Prepare messages for API call
+ messages = [system_msg]
+ messages.append({"role": "user", "content": f"Here is the data to reference:\n\n{data_text}\n\nUser question: {user_input}"})
+
+ # Get response from OpenAI
+ response = get_chat_response(messages)
+
+ if response:
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+ # Display chat history
+ st.subheader("Chat History")
+ for message in st.session_state.chat_history:
+ if message["role"] == "user":
+ st.write("**You:**", message["content"])
+ else:
+ st.write("**Assistant:**")
+ st.markdown(message["content"], unsafe_allow_html=True)
+ st.write("---") # Add a separator between messages
+
+
+ ###############################################################################
+ # Tab: Internal Validation
+ ###############################################################################
+
+else: # Simple view
+ st.header("Automatic Mode")
+
+ # Initialize session state for automatic view
+ if 'df' not in st.session_state:
+ default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
+ df = load_default_dataset(default_dataset_path)
+ if df is not None:
+ st.session_state['df'] = df.copy()
+ st.session_state['using_default_dataset'] = True
+ st.session_state['filtered_df'] = df.copy()
+
+ # Set default text columns if not already set
+ if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
+ default_text_cols = []
+ if 'Title' in df.columns and 'Description' in df.columns:
+ default_text_cols = ['Title', 'Description']
+ st.session_state['text_columns'] = default_text_cols
+
+ # Single search bar for automatic processing
+ #st.write("Enter your query to automatically search, cluster, and summarize the results:")
+ query = st.text_input("Write your query here:")
+
+
+
+
+ if st.button("SNAP!"):
+ if query.strip():
+ # Step 1: Semantic Search
+ st.write("### Step 1: Semantic Search")
+ with st.spinner("Performing Semantic Search..."):
+ text_columns = st.session_state.get('text_columns', [])
+ if text_columns:
+ df_full = st.session_state['df']
+ embeddings, _ = load_or_compute_embeddings(
+ df_full,
+ st.session_state.get('using_default_dataset', False),
+ st.session_state.get('uploaded_file_name'),
+ text_columns
+ )
+
+ if embeddings is not None:
+ model = get_embedding_model()
+ df_filtered = st.session_state['filtered_df'].fillna("")
+ search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
+
+ subset_indices = df_filtered.index
+ subset_embeddings = embeddings[subset_indices]
+
+ query_embedding = model.encode([query], device=device)
+ similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
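+ # similarities is a 1-D array aligned with df_filtered's rows; cosine values
+ # fall in [-1, 1], and only rows above the 0.35 threshold set below are kept.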
+
+ similarity_threshold = 0.35 # Default threshold
+ above_threshold_indices = np.where(similarities > similarity_threshold)[0]
+
+ if len(above_threshold_indices) > 0:
+ selected_indices = subset_indices[above_threshold_indices]
+ results = df_filtered.loc[selected_indices].copy()
+ results['similarity_score'] = similarities[above_threshold_indices]
+ results.sort_values(by='similarity_score', ascending=False, inplace=True)
+ st.session_state['search_results'] = results.copy()
+ st.write(f"Found {len(results)} relevant documents")
+ else:
+ st.warning("No results found above the similarity threshold.")
+ st.stop()
+
+ # Step 2: Clustering
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ st.write("### Step 2: Clustering")
+ with st.spinner("Performing clustering..."):
+ df_to_cluster = st.session_state['search_results'].copy()
+ dfc = df_to_cluster.copy().fillna("")
+ dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
+
+ # Filter embeddings to those rows
+ selected_indices = dfc.index
+ embeddings_clustering = embeddings[selected_indices]
+
+ # Basic cleaning
+ stop_words = set(stopwords.words('english'))
+ texts_cleaned = []
+ for text in dfc['text'].tolist():
+ try:
+ word_tokens = word_tokenize(text)
+ filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
+ texts_cleaned.append(filtered_text)
+ except Exception as e:
+ texts_cleaned.append(text)
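+ # word_tokenize needs NLTK tokenizer data (e.g. 'punkt'); if tokenization
+ # fails, the raw text is kept so clustering can still proceed.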
+
+ min_cluster_size = 5 # Default value
+
+ try:
+ # Convert embeddings to CPU numpy if needed
+ if torch.is_tensor(embeddings_clustering):
+ embeddings_for_clustering = embeddings_clustering.cpu().numpy()
+ else:
+ embeddings_for_clustering = embeddings_clustering
+
+ # Build the HDBSCAN model
+ hdbscan_model = HDBSCAN(
+ min_cluster_size=min_cluster_size,
+ metric='euclidean',
+ cluster_selection_method='eom'
+ )
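+ # min_cluster_size is the smallest group HDBSCAN will call a cluster;
+ # cluster_selection_method='eom' (excess of mass) favors larger, more
+ # stable clusters over many small ones.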
+
+ # Build the BERTopic model
+ topic_model = BERTopic(
+ embedding_model=get_embedding_model(),
+ hdbscan_model=hdbscan_model
+ )
+
+ # Fit the model and get topics
+ topics, probs = topic_model.fit_transform(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
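+ # topics assigns one label per document, with -1 marking outliers;
+ # probs (unused downstream here) holds the confidence of each assignment.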
+
+ # Store results
+ dfc['Topic'] = topics
+ st.session_state['topic_model'] = topic_model
+ st.session_state['clustered_data'] = dfc.copy()
+ st.session_state['clustering_completed'] = True
+
+ # Display clustering results summary
+ unique_topics = sorted(list(set(topics)))
+ num_clusters = len([t for t in unique_topics if t != -1]) # Exclude noise cluster (-1)
+ noise_docs = len([t for t in topics if t == -1])
+ clustered_docs = len(topics) - noise_docs
+
+ st.write(f"Found {num_clusters} distinct clusters")
+ #st.write(f"Documents successfully clustered: {clustered_docs}")
+ #if noise_docs > 0:
+ # st.write(f"Documents not fitting in any cluster: {noise_docs}")
+
+ # Show quick cluster overview
+ cluster_info = []
+ for t in unique_topics:
+ if t != -1: # Skip noise cluster in the overview
+ cluster_docs = dfc[dfc['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
+ cluster_info.append((t, count, top_keywords))
+
+ if cluster_info:
+ #st.write("### Quick Cluster Overview:")
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
+ # st.dataframe(
+ # cluster_df,
+ # column_config={
+ # "Topic": st.column_config.NumberColumn("Topic", help="Topic ID"),
+ # "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ # "Top Keywords": st.column_config.TextColumn(
+ # "Top Keywords",
+ # help="Top 5 keywords that characterize this topic"
+ # )
+ # },
+ # hide_index=True
+ # )
+
+ # Generate visualizations
+ try:
+ st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
+ except Exception:
+ st.session_state['intertopic_distance_fig'] = None
+
+ try:
+ st.session_state['topic_document_fig'] = topic_model.visualize_documents(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+ except Exception:
+ st.session_state['topic_document_fig'] = None
+
+ try:
+ hierarchy = topic_model.hierarchical_topics(texts_cleaned)
+ st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
+ st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
+ except Exception:
+ st.session_state['hierarchy'] = pd.DataFrame()
+ st.session_state['hierarchy_fig'] = None
+
+ except Exception as e:
+ st.error(f"An error occurred during clustering: {str(e)}")
+ st.stop()
+
+ # Step 3: Summarization
+ if st.session_state.get('clustering_completed', False):
+ st.write("### Step 3: Summarization")
+
+ # Initialize OpenAI client
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
+ if not openai_api_key:
+ st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
+ st.stop()
+
+ llm = ChatOpenAI(
+ api_key=openai_api_key,
+ model_name='gpt-4o-mini',
+ temperature=0.7,
+ max_tokens=1000
+ )
+
+ df_scope = st.session_state['clustered_data']
+ unique_selected_topics = df_scope['Topic'].unique()
+
+ # Process summaries in parallel
+ with st.spinner("Generating summaries..."):
+ local_system_message = SystemMessagePromptTemplate.from_template("""You are an expert summarizer skilled in creating concise and relevant summaries.
+You will be given text and an objective context. Please produce a clear, cohesive,
+and thematically relevant summary.
+Focus on key points, insights, or patterns that emerge from the text.""")
+ local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
+
+ # Find URL column if it exists
+ url_column = next((col for col in df_scope.columns if 'url' in col.lower() or 'link' in col.lower() or 'pdf' in col.lower()), None)
+
+ summaries = process_summaries_in_parallel(
+ df_scope=df_scope,
+ unique_selected_topics=unique_selected_topics,
+ llm=llm,
+ chat_prompt=local_chat_prompt,
+ enable_references=True,
+ reference_id_column=df_scope.columns[0],
+ url_column=url_column, # Add URL column for clickable links
+ max_workers=min(16, len(unique_selected_topics))
+ )
+
+ if summaries:
+ summary_df = pd.DataFrame(summaries)
+ st.session_state['summary_df'] = summary_df
+
+ # Display updated cluster overview
+ if 'Cluster_Name' in summary_df.columns:
+ st.write("### Updated Topic Overview:")
+ cluster_info = []
+ for t in unique_selected_topics:
+ cluster_docs = df_scope[df_scope['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
+ cluster_name = summary_df[summary_df['Topic'] == t]['Cluster_Name'].iloc[0]
+ cluster_info.append((t, cluster_name, count, top_keywords))
+
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Cluster_Name", "Count", "Top Keywords"])
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ # Generate and display high-level summary
+ with st.spinner("Generating high-level summary..."):
+ formatted_summaries = []
+ summary_batches = []
+ current_batch = []
+ current_batch_tokens = 0
+ MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75)
+
+ for _, row in summary_df.iterrows():
+ summary_text = row.get('Enhanced_Summary', row['Summary'])
+ formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
+ summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
+
+ if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
+ if current_batch:
+ summary_batches.append(current_batch)
+ current_batch = []
+ current_batch_tokens = 0
+
+ current_batch.append(formatted_summary)
+ current_batch_tokens += summary_tokens
+
+ if current_batch:
+ summary_batches.append(current_batch)
+
+ # Process each batch separately first
+ batch_overviews = []
+ for i, batch in enumerate(summary_batches, 1):
+ st.write(f"Processing summary batch {i} of {len(summary_batches)}...")
+ batch_text = "\n\n".join(batch)
+ batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or ID.
+
+Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
+1. Preserve all hyperlinked references exactly as they appear in the input summaries
+2. Maintain the HTML anchor tags (<a>...</a>) intact when using information from the summaries
+3. Keep the markdown formatting for better readability
+4. Create clear sections with headings for different themes
+5. Use bullet points or numbered lists where appropriate
+6. Focus on synthesizing the main themes and findings
+
+Here are the cluster summaries to synthesize:
+
+{batch_text}"""
+
+ high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
+ batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
+ batch_overviews.append(batch_overview)
+
+ # Now create the final synthesis
+ if len(batch_overviews) > 1:
+ st.write("Generating final synthesis...")
+ combined_overviews = "\n\n".join([f"# Part {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ final_prompt = f"""Below are multiple overview summaries, each covering different aspects of CGIAR research results. Each part maintains its original references to source documents.
+
+Please create a final comprehensive synthesis that:
+1. Integrates the key themes and findings from all parts into a cohesive narrative
+2. Preserves all hyperlinked references exactly as they appear
+3. Maintains the HTML anchor tags (<a>...</a>) intact
+4. Uses clear section headings and structured formatting
+5. Highlights cross-cutting themes and relationships between different aspects
+6. Provides a clear introduction and conclusion
+
+Here are the overviews to synthesize:
+
+{combined_overviews}"""
+
+ final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
+ if final_prompt_tokens > MAX_SAFE_TOKENS:
+ # If too long, just combine with headers
+ high_level_summary = "# Comprehensive Overview\n\n" + "\n\n".join([f"# Part {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ else:
+ high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
+ high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
+ else:
+ # If only one batch, use its overview directly
+ high_level_summary = batch_overviews[0]
+
+ st.session_state['high_level_summary'] = high_level_summary
+ st.session_state['enhanced_summary'] = high_level_summary
+
+ # Display summaries
+ st.write("### High-Level Summary:")
+ with st.expander("High-Level Summary", expanded=True):
+ st.markdown(high_level_summary, unsafe_allow_html=True)
+
+ st.write("### Cluster Summaries:")
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ with st.expander(f"Topic {row['Topic']} - {cluster_name}", expanded=False):
+ st.markdown(row.get('Enhanced_Summary', row['Summary']), unsafe_allow_html=True)
+ st.markdown("##### About this tool")
+ with st.expander("Click to expand/collapse", expanded=True):
+ st.markdown("""
+ This tool draws on CGIAR quality-assured results data from 2022-2024 to provide verifiable responses to user questions about the themes and areas CGIAR has worked on and is working on.
+
+ **Tips:**
+ - **Craft a phrase** that describes your topic of interest (e.g., `"climate-smart agriculture"`, `"gender equality livestock"`).
+ - Avoid writing full questions — **this is not a chatbot**.
+ - Combine **related terms** for better results (e.g., `"irrigation water access smallholders"`).
+ - Focus on **concepts or themes** — not single words like `"climate"` or `"yield"` alone.
+ - Example good queries:
+ - `"climate adaptation smallholder farming"`
+ - `"digital agriculture innovations"`
+ - `"nutrition-sensitive value chains"`
+
+ **Example use case**:
+ You're interested in CGIAR's contributions to **poverty reduction through improved maize varieties in Africa**.
+ A good search phrase would be:
+ 👉 `"poverty reduction maize Africa"`
+ This will retrieve results related to improved crop varieties, livelihood outcomes, and region-specific interventions, even if the documents use different wording like *"enhanced maize genetics"*, *"smallholder income"*, or *"eastern Africa trials"*.
""")
\ No newline at end of file