annikwag committed on
Commit
367acc4
verified
1 Parent(s): 540cd3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -67
app.py CHANGED
@@ -23,20 +23,10 @@ DEDICATED_ENDPOINT = "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingfa
23
  WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
24
 
25
  def get_rag_answer(query, top_results):
26
- """
27
- Constructs a prompt from the query and the page contexts of the top results,
28
- truncates the context to avoid exceeding the token limit, then sends it to the
29
- dedicated endpoint and returns only the generated answer.
30
- """
31
- # Combine the context from the top results (adjust the separator as needed)
32
  context = "\n\n".join([res.payload["page_content"] for res in top_results])
33
-
34
- # Truncate the context to a maximum number of characters (e.g., 12000 characters)
35
  max_context_chars = 15000
36
  if len(context) > max_context_chars:
37
  context = context[:max_context_chars]
38
-
39
- # Build the prompt, instructing the model to only output the final answer.
40
  prompt = (
41
  "Using the following context, answer the question concisely. "
42
  "Only output the final answer below, without repeating the context or question.\n\n"
@@ -44,37 +34,29 @@ def get_rag_answer(query, top_results):
44
  f"Question: {query}\n\n"
45
  "Answer:"
46
  )
47
-
48
  headers = {"Authorization": f"Bearer {WRITE_ACCESS_TOKEN}"}
49
  payload = {
50
  "inputs": prompt,
51
- "parameters": {
52
- "max_new_tokens": 150 # Adjust max tokens as needed
53
- }
54
  }
55
-
56
  response = requests.post(DEDICATED_ENDPOINT, headers=headers, json=payload)
57
  if response.status_code == 200:
58
  result = response.json()
59
  answer = result[0]["generated_text"]
60
- # If the model returns the full prompt, split and extract only the portion after "Answer:"
61
  if "Answer:" in answer:
62
  answer = answer.split("Answer:")[-1].strip()
63
  return answer
64
  else:
65
  return f"Error in generating answer: {response.text}"
66
 
67
-
68
- #######
69
- # Helper function: Format project id (e.g., "201940485" -> "2019.4048.5")
70
  def format_project_id(pid):
71
  s = str(pid)
72
  if len(s) > 5:
73
  return s[:4] + "." + s[4:-1] + "." + s[-1]
74
  return s
75
 
76
-
77
- # Helper function: Compute title from metadata using name.en (or name.de if empty)
78
  def compute_title(metadata):
79
  name_en = metadata.get("name.en", "").strip()
80
  name_de = metadata.get("name.de", "").strip()
@@ -84,7 +66,7 @@ def compute_title(metadata):
84
  return f"{base} [{format_project_id(pid)}]"
85
  return base or "No Title"
86
 
87
- # Helper function: Get CRS filter options from all documents in the collection
88
  @st.cache_data
89
  def get_crs_options(_client, collection_name):
90
  results = hybrid_search(_client, "", collection_name)
@@ -99,8 +81,7 @@ def get_crs_options(_client, collection_name):
99
  crs_set.add(crs_combined)
100
  return sorted(crs_set)
101
 
102
-
103
- # Update filter_results to also filter by crs_combined.
104
  def filter_results(results, country_filter, region_filter, end_year_range, crs_filter):
105
  filtered = []
106
  for r in results:
@@ -128,30 +109,32 @@ def filter_results(results, country_filter, region_filter, end_year_range, crs_f
128
  else:
129
  countries_in_region = c_list
130
 
131
- # Filter by CRS: compute crs_combined and compare to the selected filter.
132
  crs_key = metadata.get("crs_key", "").strip()
133
  crs_value = metadata.get("crs_value", "").strip()
134
  crs_combined = f"{crs_key}: {crs_value}" if (crs_key or crs_value) else ""
135
 
136
- if crs_filter != "All/Not allocated" and crs_filter != crs_combined:
137
- continue
 
 
 
 
 
138
 
139
- if ((country_filter == "All/Not allocated" or selected_iso_code in c_list)
140
  and (region_filter == "All/Not allocated" or countries_in_region)
141
- and (end_year_range[0] <= end_year_val <= end_year_range[1])):
142
  filtered.append(r)
143
  return filtered
144
 
145
- #######
146
-
147
- # get the device to be used eithe gpu or cpu
148
  device = 'cuda' if cuda.is_available() else 'cpu'
149
 
150
-
151
- st.set_page_config(page_title="SEARCH IATI",layout='wide')
152
  st.title("GIZ Project Database (PROTOTYPE)")
153
  var = st.text_input("Enter Search Question")
154
 
 
155
  # Load the region lookup CSV
156
  region_lookup_path = "docStore/regions_lookup.csv"
157
  region_df = load_region_data(region_lookup_path)
@@ -196,14 +179,19 @@ def get_country_name_and_region_mapping(_client, collection_name, region_df):
196
 
197
  client = get_client()
198
  country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(client, collection_name, region_df)
199
- unique_country_names = sorted(country_name_mapping.keys()) # List of country names
200
 
201
- # Layout filters in columns: add a new filter for CRS in col4.
202
  col1, col2, col3, col4 = st.columns([1, 1, 1, 4])
203
  with col1:
204
  region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
 
 
 
 
 
205
  with col2:
206
- country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names if (filtered_country_names := unique_country_names) else unique_country_names)
207
  with col3:
208
  current_year = datetime.now().year
209
  default_start_year = current_year - 4
@@ -212,46 +200,32 @@ with col4:
212
  crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
213
  crs_filter = st.selectbox("CRS", crs_options)
214
 
215
- # Checkbox to control whether to show only exact matches
216
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
217
 
218
-
219
-
220
- # Run the search
221
-
222
- # 1) Adjust limit so we get more than 15 results
223
- results = hybrid_search(client, var, collection_name, limit=500) # e.g., 100 or 200
224
-
225
- # results is a tuple: (semantic_results, lexical_results)
226
  semantic_all = results[0]
227
  lexical_all = results[1]
228
 
229
- # 2) Filter out content < 20 chars (as intermediate fix to problem that e.g. super short paragraphs with few chars get high similarity score)
230
- semantic_all = [
231
- r for r in semantic_all if len(r.payload["page_content"]) >= 5
232
- ]
233
- lexical_all = [
234
- r for r in lexical_all if len(r.payload["page_content"]) >= 5
235
- ]
236
 
237
- # 2) Apply a threshold to SEMANTIC results (score >= 0.4)
238
  semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
239
 
240
-
241
  filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range, crs_filter)
242
  filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range, crs_filter)
243
- filtered_semantic_no_dupe = remove_duplicates(filtered_semantic) # ToDo remove duplicates again?
 
244
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
245
 
246
- # Define a helper function to format currency values
247
  def format_currency(value):
248
  try:
249
- # Convert to float then int for formatting (assumes whole numbers)
250
  return f"€{int(float(value)):,}"
251
  except (ValueError, TypeError):
252
  return value
253
-
254
- # Helper function to highlight query matches (case-insensitive)
255
  def highlight_query(text, query):
256
  pattern = re.compile(re.escape(query), re.IGNORECASE)
257
  return pattern.sub(lambda m: f"**{m.group(0)}**", text)
@@ -275,15 +249,12 @@ if show_exact_matches:
275
  st.divider()
276
  for res in top_results:
277
  metadata = res.payload.get('metadata', {})
278
- # Compute new title if not already set
279
  if "title" not in metadata:
280
  metadata["title"] = compute_title(metadata)
281
- # Use new title instead of project_name and highlight query if present
282
  display_title = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
283
  proj_id = metadata.get('id', 'Unknown')
284
  st.markdown(f"#### {display_title} [{proj_id}]")
285
 
286
- # Build snippet with potential highlighting
287
  objectives = metadata.get("objectives", "")
288
  desc_de = metadata.get("description.de", "")
289
  desc_en = metadata.get("description.en", "")
@@ -299,13 +270,11 @@ if show_exact_matches:
299
  with st.expander("Show more"):
300
  st.write(remainder_text)
301
 
302
- # Keywords
303
  full_text = res.payload['page_content']
304
  top_keywords = extract_top_keywords(full_text, top_n=5)
305
  if top_keywords:
306
  st.markdown(f"_{' · '.join(top_keywords)}_")
307
 
308
- # Country info
309
  try:
310
  c_list = json.loads(metadata.get('countries', "[]").replace("'", '"'))
311
  except json.JSONDecodeError:
@@ -318,7 +287,6 @@ if show_exact_matches:
318
  matched_countries.append(resolved_name)
319
 
320
  additional_text = f"Country: **{', '.join(matched_countries) if matched_countries else 'Unknown'}**"
321
- # Add contact info if available and not [email protected]
322
  contact = metadata.get("contact", "").strip()
323
  if contact and contact.lower() != "[email protected]":
324
  additional_text += f" | Contact: **{contact}**"
@@ -380,7 +348,6 @@ else:
380
  additional_text += f" | Contact: **{contact}**"
381
  st.markdown(additional_text)
382
  st.divider()
383
-
384
  # for i in results:
385
  # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
386
  # st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
 
23
  WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
24
 
25
  def get_rag_answer(query, top_results):
 
 
 
 
 
 
26
  context = "\n\n".join([res.payload["page_content"] for res in top_results])
 
 
27
  max_context_chars = 15000
28
  if len(context) > max_context_chars:
29
  context = context[:max_context_chars]
 
 
30
  prompt = (
31
  "Using the following context, answer the question concisely. "
32
  "Only output the final answer below, without repeating the context or question.\n\n"
 
34
  f"Question: {query}\n\n"
35
  "Answer:"
36
  )
 
37
  headers = {"Authorization": f"Bearer {WRITE_ACCESS_TOKEN}"}
38
  payload = {
39
  "inputs": prompt,
40
+ "parameters": {"max_new_tokens": 150}
 
 
41
  }
 
42
  response = requests.post(DEDICATED_ENDPOINT, headers=headers, json=payload)
43
  if response.status_code == 200:
44
  result = response.json()
45
  answer = result[0]["generated_text"]
 
46
  if "Answer:" in answer:
47
  answer = answer.split("Answer:")[-1].strip()
48
  return answer
49
  else:
50
  return f"Error in generating answer: {response.text}"
51
 
52
+ # Helper: Format project id (e.g., "201940485" -> "2019.4048.5")
 
 
53
  def format_project_id(pid):
54
  s = str(pid)
55
  if len(s) > 5:
56
  return s[:4] + "." + s[4:-1] + "." + s[-1]
57
  return s
58
 
59
+ # Helper: Compute title from metadata using name.en (or name.de if empty)
 
60
  def compute_title(metadata):
61
  name_en = metadata.get("name.en", "").strip()
62
  name_de = metadata.get("name.de", "").strip()
 
66
  return f"{base} [{format_project_id(pid)}]"
67
  return base or "No Title"
68
 
69
+ # Helper: Get CRS filter options from all documents
70
  @st.cache_data
71
  def get_crs_options(_client, collection_name):
72
  results = hybrid_search(_client, "", collection_name)
 
81
  crs_set.add(crs_combined)
82
  return sorted(crs_set)
83
 
84
+ # Revised filter_results: Allow missing end_year or CRS; enforce CRS only when present.
 
85
  def filter_results(results, country_filter, region_filter, end_year_range, crs_filter):
86
  filtered = []
87
  for r in results:
 
109
  else:
110
  countries_in_region = c_list
111
 
 
112
  crs_key = metadata.get("crs_key", "").strip()
113
  crs_value = metadata.get("crs_value", "").strip()
114
  crs_combined = f"{crs_key}: {crs_value}" if (crs_key or crs_value) else ""
115
 
116
+ # Only enforce CRS filter if result has a CRS value.
117
+ if crs_filter != "All/Not allocated" and crs_combined:
118
+ if crs_filter != crs_combined:
119
+ continue
120
+
121
+ # Allow projects with no valid end_year to pass (if end_year_val is 0)
122
+ year_ok = True if end_year_val == 0 else (end_year_range[0] <= end_year_val <= end_year_range[1])
123
 
124
+ if ((country_filter == "All/Not allocated" or (selected_iso_code and selected_iso_code in c_list))
125
  and (region_filter == "All/Not allocated" or countries_in_region)
126
+ and year_ok):
127
  filtered.append(r)
128
  return filtered
129
 
130
+ # Get the device to be used (GPU or CPU)
 
 
131
  device = 'cuda' if cuda.is_available() else 'cpu'
132
 
133
+ st.set_page_config(page_title="SEARCH IATI", layout='wide')
 
134
  st.title("GIZ Project Database (PROTOTYPE)")
135
  var = st.text_input("Enter Search Question")
136
 
137
+
138
  # Load the region lookup CSV
139
  region_lookup_path = "docStore/regions_lookup.csv"
140
  region_df = load_region_data(region_lookup_path)
 
179
 
180
  client = get_client()
181
  country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(client, collection_name, region_df)
182
+ unique_country_names = sorted(country_name_mapping.keys())
183
 
184
+ # Layout filters in columns
185
  col1, col2, col3, col4 = st.columns([1, 1, 1, 4])
186
  with col1:
187
  region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
188
+ # Compute filtered_country_names based on region_filter:
189
+ if region_filter == "All/Not allocated":
190
+ filtered_country_names = unique_country_names
191
+ else:
192
+ filtered_country_names = [name for name, code in country_name_mapping.items() if iso_code_to_sub_region.get(code) == region_filter]
193
  with col2:
194
+ country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
195
  with col3:
196
  current_year = datetime.now().year
197
  default_start_year = current_year - 4
 
200
  crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
201
  crs_filter = st.selectbox("CRS", crs_options)
202
 
203
+ # Checkbox for exact matches
204
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
205
 
206
+ # Run the search
207
+ results = hybrid_search(client, var, collection_name, limit=500)
 
 
 
 
 
 
208
  semantic_all = results[0]
209
  lexical_all = results[1]
210
 
211
+ semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
212
+ lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
 
 
 
 
 
213
 
 
214
  semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
215
 
 
216
  filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range, crs_filter)
217
  filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range, crs_filter)
218
+
219
+ filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
220
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
221
 
 
222
  def format_currency(value):
223
  try:
 
224
  return f"€{int(float(value)):,}"
225
  except (ValueError, TypeError):
226
  return value
227
+
228
+ # Helper to highlight query matches (case-insensitive)
229
  def highlight_query(text, query):
230
  pattern = re.compile(re.escape(query), re.IGNORECASE)
231
  return pattern.sub(lambda m: f"**{m.group(0)}**", text)
 
249
  st.divider()
250
  for res in top_results:
251
  metadata = res.payload.get('metadata', {})
 
252
  if "title" not in metadata:
253
  metadata["title"] = compute_title(metadata)
 
254
  display_title = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
255
  proj_id = metadata.get('id', 'Unknown')
256
  st.markdown(f"#### {display_title} [{proj_id}]")
257
 
 
258
  objectives = metadata.get("objectives", "")
259
  desc_de = metadata.get("description.de", "")
260
  desc_en = metadata.get("description.en", "")
 
270
  with st.expander("Show more"):
271
  st.write(remainder_text)
272
 
 
273
  full_text = res.payload['page_content']
274
  top_keywords = extract_top_keywords(full_text, top_n=5)
275
  if top_keywords:
276
  st.markdown(f"_{' · '.join(top_keywords)}_")
277
 
 
278
  try:
279
  c_list = json.loads(metadata.get('countries', "[]").replace("'", '"'))
280
  except json.JSONDecodeError:
 
287
  matched_countries.append(resolved_name)
288
 
289
  additional_text = f"Country: **{', '.join(matched_countries) if matched_countries else 'Unknown'}**"
 
290
  contact = metadata.get("contact", "").strip()
291
  if contact and contact.lower() != "[email protected]":
292
  additional_text += f" | Contact: **{contact}**"
 
348
  additional_text += f" | Contact: **{contact}**"
349
  st.markdown(additional_text)
350
  st.divider()
 
351
  # for i in results:
352
  # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
353
  # st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")