Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,9 @@ import folium
|
|
5 |
from fastai.vision.all import *
|
6 |
from groq import Groq
|
7 |
from PIL import Image
|
|
|
|
|
|
|
8 |
|
9 |
# Load the trained model
|
10 |
learn = load_learner('export.pkl')
|
@@ -15,6 +18,9 @@ client = Groq(
|
|
15 |
api_key=os.environ.get("GROQ_API_KEY"),
|
16 |
)
|
17 |
|
|
|
|
|
|
|
18 |
# Language translations
|
19 |
translations = {
|
20 |
"en": {
|
@@ -34,7 +40,11 @@ translations = {
|
|
34 |
"answer_title": "Answer:",
|
35 |
"habitat_map_title": "Natural Habitat Map for",
|
36 |
"detailed_info_title": "Detailed Information",
|
37 |
-
"language_label": "Language / Lugha"
|
|
|
|
|
|
|
|
|
38 |
},
|
39 |
"sw": {
|
40 |
"app_title": "Mtafiti wa Ndege: Utambuzi wa Kiotomatiki kwa Watafiti",
|
@@ -53,38 +63,117 @@ translations = {
|
|
53 |
"answer_title": "Jibu:",
|
54 |
"habitat_map_title": "Ramani ya Makazi Asilia ya",
|
55 |
"detailed_info_title": "Taarifa za Kina",
|
56 |
-
"language_label": "Language / Lugha"
|
|
|
|
|
|
|
|
|
57 |
}
|
58 |
}
|
59 |
|
60 |
def clean_bird_name(name):
|
61 |
"""Clean bird name by removing numbers and special characters, and fix formatting"""
|
|
|
62 |
cleaned = re.sub(r'^\d+\.', '', name)
|
|
|
63 |
cleaned = cleaned.replace('_', ' ')
|
|
|
64 |
cleaned = re.sub(r'[^\w\s]', '', cleaned)
|
|
|
65 |
cleaned = ' '.join(cleaned.split())
|
66 |
return cleaned
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def get_bird_habitat_map(bird_name, check_tanzania=True):
|
69 |
-
"""Get habitat map locations for the bird using Groq API"""
|
70 |
clean_name = clean_bird_name(bird_name)
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
try:
|
78 |
-
tanzania_check = client.chat.completions.create(
|
79 |
-
messages=[{"role": "user", "content": tanzania_check_prompt}],
|
80 |
-
model="llama-3.3-70b-versatile",
|
81 |
-
)
|
82 |
-
is_in_tanzania = "yes" in tanzania_check.choices[0].message.content.lower()
|
83 |
-
except:
|
84 |
-
is_in_tanzania = True
|
85 |
else:
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
prompt = f"""
|
89 |
Provide a JSON array of the main habitat locations for the {clean_name} bird in the world.
|
90 |
Return ONLY a JSON array with 3-5 entries, each containing:
|
@@ -104,78 +193,120 @@ def get_bird_habitat_map(bird_name, check_tanzania=True):
|
|
104 |
|
105 |
try:
|
106 |
chat_completion = client.chat.completions.create(
|
107 |
-
messages=[
|
|
|
|
|
|
|
|
|
|
|
108 |
model="llama-3.3-70b-versatile",
|
109 |
)
|
110 |
response = chat_completion.choices[0].message.content
|
111 |
|
|
|
|
|
|
|
|
|
|
|
112 |
json_match = re.search(r'\[.*\]', response, re.DOTALL)
|
113 |
if json_match:
|
114 |
locations = json.loads(json_match.group())
|
115 |
else:
|
|
|
116 |
locations = [
|
117 |
{"name": "Primary habitat region", "lat": 0, "lon": 0,
|
118 |
"description": "Could not retrieve specific habitat information for this bird."}
|
119 |
]
|
|
|
|
|
|
|
|
|
120 |
return locations, is_in_tanzania
|
|
|
121 |
except Exception as e:
|
122 |
return [{"name": "Error retrieving data", "lat": 0, "lon": 0,
|
123 |
"description": "Please try again or check your connection."}], False
|
124 |
|
125 |
def create_habitat_map(habitat_locations):
|
126 |
"""Create a folium map with the habitat locations"""
|
|
|
127 |
valid_coords = [(loc.get("lat", 0), loc.get("lon", 0))
|
128 |
for loc in habitat_locations
|
129 |
if loc.get("lat", 0) != 0 or loc.get("lon", 0) != 0]
|
130 |
|
131 |
if valid_coords:
|
|
|
132 |
avg_lat = sum(lat for lat, _ in valid_coords) / len(valid_coords)
|
133 |
avg_lon = sum(lon for _, lon in valid_coords) / len(valid_coords)
|
|
|
134 |
m = folium.Map(location=[avg_lat, avg_lon], zoom_start=3)
|
135 |
else:
|
|
|
136 |
m = folium.Map(location=[20, 0], zoom_start=2)
|
137 |
|
|
|
138 |
for location in habitat_locations:
|
139 |
name = location.get("name", "Unknown")
|
140 |
lat = location.get("lat", 0)
|
141 |
lon = location.get("lon", 0)
|
142 |
description = location.get("description", "No description available")
|
143 |
|
|
|
144 |
if lat == 0 and lon == 0:
|
145 |
continue
|
146 |
|
|
|
147 |
folium.Marker(
|
148 |
location=[lat, lon],
|
149 |
popup=folium.Popup(f"<b>{name}</b><br>{description}", max_width=300),
|
150 |
tooltip=name
|
151 |
).add_to(m)
|
152 |
|
|
|
153 |
map_html = m._repr_html_()
|
154 |
return map_html
|
155 |
|
156 |
def format_bird_info(raw_info, language="en"):
|
157 |
"""Improve the formatting of bird information"""
|
|
|
158 |
formatted = raw_info
|
|
|
|
|
159 |
warning_text = "NOT TYPICALLY FOUND IN TANZANIA"
|
160 |
warning_translation = "HAPATIKANI SANA TANZANIA" if language == "sw" else warning_text
|
161 |
|
|
|
162 |
formatted = re.sub(r'#+\s+' + warning_text,
|
163 |
-
|
164 |
-
|
165 |
|
|
|
166 |
formatted = re.sub(r'#+\s+(.*)', r'<h3>\1</h3>', formatted)
|
|
|
|
|
167 |
formatted = re.sub(r'\n\*\s+(.*)', r'<p>• \1</p>', formatted)
|
168 |
formatted = re.sub(r'\n([^<\n].*)', r'<p>\1</p>', formatted)
|
169 |
|
|
|
170 |
formatted = formatted.replace('<p><p>', '<p>')
|
171 |
formatted = formatted.replace('</p></p>', '</p>')
|
|
|
172 |
return formatted
|
173 |
|
174 |
def get_bird_info(bird_name, language="en"):
|
175 |
-
"""Get detailed information about a bird using Groq API"""
|
176 |
clean_name = clean_bird_name(bird_name)
|
177 |
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
prompt = f"""
|
181 |
Provide detailed information about the {clean_name} bird, including:
|
@@ -192,130 +323,184 @@ def get_bird_info(bird_name, language="en"):
|
|
192 |
|
193 |
try:
|
194 |
chat_completion = client.chat.completions.create(
|
195 |
-
messages=[
|
|
|
|
|
|
|
|
|
|
|
196 |
model="llama-3.3-70b-versatile",
|
197 |
)
|
198 |
-
|
|
|
|
|
|
|
199 |
except Exception as e:
|
200 |
error_msg = "Hitilafu katika kupata taarifa" if language == "sw" else "Error fetching information"
|
201 |
return f"{error_msg}: {str(e)}"
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
def predict_and_get_info(img, language="en"):
|
204 |
"""Predict bird species and get detailed information"""
|
|
|
205 |
t = translations[language]
|
206 |
|
207 |
-
|
208 |
-
|
|
|
|
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
"Tafadhali pakia picha ya wazi ya ndege kwa utambuzi sahihi."
|
228 |
-
)
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
custom_css = """
|
231 |
<style>
|
232 |
.bird-container {
|
233 |
font-family: Arial, sans-serif;
|
234 |
padding: 10px;
|
235 |
}
|
236 |
-
.
|
237 |
-
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
padding: 10px;
|
241 |
margin-bottom: 15px;
|
242 |
border-radius: 4px;
|
243 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
</style>
|
245 |
"""
|
246 |
|
|
|
|
|
|
|
247 |
combined_info = f"""
|
248 |
{custom_css}
|
249 |
<div class="bird-container">
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
</div>
|
253 |
</div>
|
254 |
"""
|
255 |
-
return prediction_results, combined_info, ""
|
256 |
-
|
257 |
-
habitat_locations, is_in_tanzania = get_bird_habitat_map(top_bird)
|
258 |
-
habitat_map_html = create_habitat_map(habitat_locations)
|
259 |
-
|
260 |
-
bird_info = get_bird_info(top_bird, language)
|
261 |
-
formatted_info = format_bird_info(bird_info, language)
|
262 |
-
|
263 |
-
custom_css = """
|
264 |
-
<style>
|
265 |
-
.bird-container {
|
266 |
-
font-family: Arial, sans-serif;
|
267 |
-
padding: 10px;
|
268 |
-
}
|
269 |
-
.map-container {
|
270 |
-
height: 400px;
|
271 |
-
width: 100%;
|
272 |
-
border: 1px solid #ddd;
|
273 |
-
border-radius: 8px;
|
274 |
-
overflow: hidden;
|
275 |
-
margin-bottom: 20px;
|
276 |
-
}
|
277 |
-
.info-container {
|
278 |
-
line-height: 1.6;
|
279 |
-
}
|
280 |
-
.info-container h3 {
|
281 |
-
margin-top: 20px;
|
282 |
-
margin-bottom: 10px;
|
283 |
-
color: #2c3e50;
|
284 |
-
border-bottom: 1px solid #eee;
|
285 |
-
padding-bottom: 5px;
|
286 |
-
}
|
287 |
-
.info-container p {
|
288 |
-
margin-bottom: 10px;
|
289 |
-
}
|
290 |
-
.alert {
|
291 |
-
padding: 10px;
|
292 |
-
margin-bottom: 15px;
|
293 |
-
border-radius: 4px;
|
294 |
-
}
|
295 |
-
.alert-warning {
|
296 |
-
background-color: #fcf8e3;
|
297 |
-
border: 1px solid #faebcc;
|
298 |
-
color: #8a6d3b;
|
299 |
-
}
|
300 |
-
</style>
|
301 |
-
"""
|
302 |
-
|
303 |
-
combined_info = f"""
|
304 |
-
{custom_css}
|
305 |
-
<div class="bird-container">
|
306 |
-
<h2>{t['habitat_map_title']} {clean_top_bird}</h2>
|
307 |
-
<div class="map-container">
|
308 |
-
{habitat_map_html}
|
309 |
-
</div>
|
310 |
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
</div>
|
316 |
-
"""
|
317 |
-
|
318 |
-
return prediction_results, combined_info, clean_top_bird
|
319 |
|
320 |
def follow_up_question(question, bird_name, language="en"):
|
321 |
"""Allow researchers to ask follow-up questions about the identified bird"""
|
@@ -324,7 +509,16 @@ def follow_up_question(question, bird_name, language="en"):
|
|
324 |
if not question.strip() or not bird_name:
|
325 |
return "Please identify a bird first and ask a specific question about it." if language == "en" else "Tafadhali tambua ndege kwanza na uulize swali maalum kuhusu ndege huyo."
|
326 |
|
327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
prompt = f"""
|
330 |
The researcher is asking about the {bird_name} bird: "{question}"
|
@@ -341,19 +535,29 @@ def follow_up_question(question, bird_name, language="en"):
|
|
341 |
|
342 |
try:
|
343 |
chat_completion = client.chat.completions.create(
|
344 |
-
messages=[
|
|
|
|
|
|
|
|
|
|
|
345 |
model="llama-3.3-70b-versatile",
|
346 |
)
|
347 |
-
|
|
|
|
|
|
|
348 |
except Exception as e:
|
349 |
error_msg = "Hitilafu katika kupata taarifa" if language == "sw" else "Error fetching information"
|
350 |
return f"{error_msg}: {str(e)}"
|
351 |
|
352 |
# Create the Gradio interface
|
353 |
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
|
354 |
current_lang = gr.State("en")
|
355 |
current_bird = gr.State("")
|
356 |
|
|
|
357 |
with gr.Row():
|
358 |
with gr.Column(scale=3):
|
359 |
title_md = gr.Markdown(f"# {translations['en']['app_title']}")
|
@@ -364,8 +568,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
364 |
value="English"
|
365 |
)
|
366 |
|
|
|
367 |
description_md = gr.Markdown(f"{translations['en']['app_description']}")
|
368 |
|
|
|
369 |
with gr.Row():
|
370 |
with gr.Column(scale=1):
|
371 |
input_image = gr.Image(type="pil", label=translations['en']['upload_label'])
|
@@ -375,8 +581,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
375 |
prediction_output = gr.Label(label=translations['en']['predictions_label'], num_top_classes=5)
|
376 |
bird_info_output = gr.HTML(label=translations['en']['bird_info_label'])
|
377 |
|
|
|
378 |
gr.Markdown("---")
|
379 |
|
|
|
380 |
questions_header = gr.Markdown(f"## {translations['en']['research_questions']}")
|
381 |
|
382 |
conversation_history = gr.Markdown("")
|
@@ -392,17 +600,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
392 |
follow_up_btn = gr.Button(translations['en']['submit_question'], variant="primary")
|
393 |
clear_btn = gr.Button(translations['en']['clear_conversation'])
|
394 |
|
395 |
-
|
396 |
-
if img is None:
|
397 |
-
return None, translations[lang]['upload_prompt'], "", ""
|
398 |
-
|
399 |
-
try:
|
400 |
-
pred_results, info, clean_bird_name = predict_and_get_info(img, lang)
|
401 |
-
return pred_results, info, clean_bird_name, ""
|
402 |
-
except Exception as e:
|
403 |
-
error_msg = "Hitilafu katika kuchakata picha" if lang == "sw" else "Error processing image"
|
404 |
-
return None, f"{error_msg}: {str(e)}", "", ""
|
405 |
-
|
406 |
def update_conversation(question, bird_name, history, lang):
|
407 |
t = translations[lang]
|
408 |
|
@@ -411,6 +609,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
411 |
|
412 |
answer = follow_up_question(question, bird_name, lang)
|
413 |
|
|
|
414 |
new_exchange = f"""
|
415 |
### {t['question_title']}
|
416 |
{question}
|
@@ -425,9 +624,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
425 |
return ""
|
426 |
|
427 |
def update_language(choice):
|
|
|
428 |
lang = "sw" if choice == "Kiswahili" else "en"
|
429 |
t = translations[lang]
|
430 |
|
|
|
431 |
return (
|
432 |
lang,
|
433 |
f"# {t['app_title']}",
|
@@ -443,6 +644,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
443 |
t['clear_conversation']
|
444 |
)
|
445 |
|
|
|
446 |
language_selector.change(
|
447 |
update_language,
|
448 |
inputs=[language_selector],
|
@@ -462,8 +664,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
462 |
]
|
463 |
)
|
464 |
|
|
|
465 |
submit_btn.click(
|
466 |
-
|
|
|
|
|
|
|
|
|
467 |
inputs=[input_image, current_lang],
|
468 |
outputs=[prediction_output, bird_info_output, current_bird, conversation_history]
|
469 |
)
|
|
|
5 |
from fastai.vision.all import *
|
6 |
from groq import Groq
|
7 |
from PIL import Image
|
8 |
+
import time
|
9 |
+
import json
|
10 |
+
from functools import lru_cache
|
11 |
|
12 |
# Load the trained model
|
13 |
learn = load_learner('export.pkl')
|
|
|
18 |
api_key=os.environ.get("GROQ_API_KEY"),
|
19 |
)
|
20 |
|
21 |
+
# Cache directory for API responses
|
22 |
+
os.makedirs("cache", exist_ok=True)
|
23 |
+
|
24 |
# Language translations
|
25 |
translations = {
|
26 |
"en": {
|
|
|
40 |
"answer_title": "Answer:",
|
41 |
"habitat_map_title": "Natural Habitat Map for",
|
42 |
"detailed_info_title": "Detailed Information",
|
43 |
+
"language_label": "Language / Lugha",
|
44 |
+
"loading": "Processing...",
|
45 |
+
"low_confidence": "Low confidence prediction. Results may not be accurate.",
|
46 |
+
"not_a_bird": "The image may not contain a bird. Please upload a clear image of a bird.",
|
47 |
+
"other_message": "This bird is not in our trained dataset or the image may not be of a bird. Please try uploading a different image."
|
48 |
},
|
49 |
"sw": {
|
50 |
"app_title": "Mtafiti wa Ndege: Utambuzi wa Kiotomatiki kwa Watafiti",
|
|
|
63 |
"answer_title": "Jibu:",
|
64 |
"habitat_map_title": "Ramani ya Makazi Asilia ya",
|
65 |
"detailed_info_title": "Taarifa za Kina",
|
66 |
+
"language_label": "Language / Lugha",
|
67 |
+
"loading": "Inachakata...",
|
68 |
+
"low_confidence": "Utabiri wa uhakika mdogo. Matokeo yanaweza kuwa si sahihi.",
|
69 |
+
"not_a_bird": "Picha inaweza isiwe ya ndege. Tafadhali pakia picha wazi ya ndege.",
|
70 |
+
"other_message": "Ndege huyu haipatikani katika hifadhidata yetu au picha inaweza isiwe ya ndege. Tafadhali jaribu kupakia picha nyingine."
|
71 |
}
|
72 |
}
|
73 |
|
74 |
def clean_bird_name(name):
|
75 |
"""Clean bird name by removing numbers and special characters, and fix formatting"""
|
76 |
+
# Remove numbers and dots at the beginning
|
77 |
cleaned = re.sub(r'^\d+\.', '', name)
|
78 |
+
# Replace underscores with spaces
|
79 |
cleaned = cleaned.replace('_', ' ')
|
80 |
+
# Remove any remaining special characters
|
81 |
cleaned = re.sub(r'[^\w\s]', '', cleaned)
|
82 |
+
# Fix spacing
|
83 |
cleaned = ' '.join(cleaned.split())
|
84 |
return cleaned
|
85 |
|
86 |
+
def get_cache_path(function_name, key):
|
87 |
+
"""Generate a cache file path"""
|
88 |
+
safe_key = re.sub(r'[^\w]', '_', key)
|
89 |
+
return f"cache/{function_name}_{safe_key}.json"
|
90 |
+
|
91 |
+
def save_to_cache(function_name, key, data):
|
92 |
+
"""Save API response to cache"""
|
93 |
+
try:
|
94 |
+
cache_path = get_cache_path(function_name, key)
|
95 |
+
with open(cache_path, 'w') as f:
|
96 |
+
json.dump({"data": data, "timestamp": time.time()}, f)
|
97 |
+
except Exception as e:
|
98 |
+
print(f"Error saving to cache: {e}")
|
99 |
+
|
100 |
+
def load_from_cache(function_name, key, max_age=86400): # Default max age: 1 day
|
101 |
+
"""Load API response from cache if it exists and is not too old"""
|
102 |
+
try:
|
103 |
+
cache_path = get_cache_path(function_name, key)
|
104 |
+
if os.path.exists(cache_path):
|
105 |
+
with open(cache_path, 'r') as f:
|
106 |
+
cached = json.load(f)
|
107 |
+
if time.time() - cached["timestamp"] < max_age:
|
108 |
+
return cached["data"]
|
109 |
+
except Exception as e:
|
110 |
+
print(f"Error loading from cache: {e}")
|
111 |
+
return None
|
112 |
+
|
113 |
+
def is_likely_bird_image(img):
|
114 |
+
"""Basic check to see if the image might contain a bird"""
|
115 |
+
try:
|
116 |
+
# Convert to numpy array for analysis
|
117 |
+
img_array = np.array(img)
|
118 |
+
|
119 |
+
# Simple checks that might indicate a bird isn't present:
|
120 |
+
# 1. Check if image is too dark or too bright overall
|
121 |
+
mean_brightness = np.mean(img_array)
|
122 |
+
if mean_brightness < 20 or mean_brightness > 235:
|
123 |
+
return False
|
124 |
+
|
125 |
+
# 2. Check if image has very little color variation (might be a solid background)
|
126 |
+
std_dev = np.std(img_array)
|
127 |
+
if std_dev < 15:
|
128 |
+
return False
|
129 |
+
|
130 |
+
# 3. If image is very small, it might not be a useful bird photo
|
131 |
+
if img_array.shape[0] < 100 or img_array.shape[1] < 100:
|
132 |
+
return False
|
133 |
+
|
134 |
+
return True
|
135 |
+
except:
|
136 |
+
# If any error occurs during the check, assume it might be a bird
|
137 |
+
return True
|
138 |
+
|
139 |
def get_bird_habitat_map(bird_name, check_tanzania=True):
|
140 |
+
"""Get habitat map locations for the bird using Groq API with caching"""
|
141 |
clean_name = clean_bird_name(bird_name)
|
142 |
|
143 |
+
# Check cache for Tanzania check
|
144 |
+
tanzania_cache_key = f"{clean_name}_tanzania"
|
145 |
+
cached_tanzania = load_from_cache("tanzania_check", tanzania_cache_key)
|
146 |
+
if cached_tanzania is not None:
|
147 |
+
is_in_tanzania = cached_tanzania
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
else:
|
149 |
+
# First check if the bird is endemic to Tanzania
|
150 |
+
if check_tanzania:
|
151 |
+
tanzania_check_prompt = f"""
|
152 |
+
Is the {clean_name} bird native to or commonly found in Tanzania?
|
153 |
+
Answer with ONLY "yes" or "no".
|
154 |
+
"""
|
155 |
+
|
156 |
+
try:
|
157 |
+
tanzania_check = client.chat.completions.create(
|
158 |
+
messages=[{"role": "user", "content": tanzania_check_prompt}],
|
159 |
+
model="llama-3.3-70b-versatile",
|
160 |
+
)
|
161 |
+
is_in_tanzania = "yes" in tanzania_check.choices[0].message.content.lower()
|
162 |
+
# Cache result
|
163 |
+
save_to_cache("tanzania_check", tanzania_cache_key, is_in_tanzania)
|
164 |
+
except:
|
165 |
+
# Default to showing Tanzania if we can't determine
|
166 |
+
is_in_tanzania = True
|
167 |
+
else:
|
168 |
+
is_in_tanzania = True
|
169 |
|
170 |
+
# Check cache for habitat locations
|
171 |
+
habitat_cache_key = f"{clean_name}_habitat"
|
172 |
+
cached_habitat = load_from_cache("habitat", habitat_cache_key)
|
173 |
+
if cached_habitat is not None:
|
174 |
+
return cached_habitat, is_in_tanzania
|
175 |
+
|
176 |
+
# Now get the habitat locations
|
177 |
prompt = f"""
|
178 |
Provide a JSON array of the main habitat locations for the {clean_name} bird in the world.
|
179 |
Return ONLY a JSON array with 3-5 entries, each containing:
|
|
|
193 |
|
194 |
try:
|
195 |
chat_completion = client.chat.completions.create(
|
196 |
+
messages=[
|
197 |
+
{
|
198 |
+
"role": "user",
|
199 |
+
"content": prompt,
|
200 |
+
}
|
201 |
+
],
|
202 |
model="llama-3.3-70b-versatile",
|
203 |
)
|
204 |
response = chat_completion.choices[0].message.content
|
205 |
|
206 |
+
# Extract JSON from response (in case there's additional text)
|
207 |
+
import json
|
208 |
+
import re
|
209 |
+
|
210 |
+
# Find JSON pattern in response
|
211 |
json_match = re.search(r'\[.*\]', response, re.DOTALL)
|
212 |
if json_match:
|
213 |
locations = json.loads(json_match.group())
|
214 |
else:
|
215 |
+
# Fallback if JSON extraction fails
|
216 |
locations = [
|
217 |
{"name": "Primary habitat region", "lat": 0, "lon": 0,
|
218 |
"description": "Could not retrieve specific habitat information for this bird."}
|
219 |
]
|
220 |
+
|
221 |
+
# Cache the result
|
222 |
+
save_to_cache("habitat", habitat_cache_key, locations)
|
223 |
+
|
224 |
return locations, is_in_tanzania
|
225 |
+
|
226 |
except Exception as e:
|
227 |
return [{"name": "Error retrieving data", "lat": 0, "lon": 0,
|
228 |
"description": "Please try again or check your connection."}], False
|
229 |
|
230 |
def create_habitat_map(habitat_locations):
|
231 |
"""Create a folium map with the habitat locations"""
|
232 |
+
# Find center point based on valid coordinates
|
233 |
valid_coords = [(loc.get("lat", 0), loc.get("lon", 0))
|
234 |
for loc in habitat_locations
|
235 |
if loc.get("lat", 0) != 0 or loc.get("lon", 0) != 0]
|
236 |
|
237 |
if valid_coords:
|
238 |
+
# Calculate the average of the coordinates
|
239 |
avg_lat = sum(lat for lat, _ in valid_coords) / len(valid_coords)
|
240 |
avg_lon = sum(lon for _, lon in valid_coords) / len(valid_coords)
|
241 |
+
# Create map centered on the average coordinates
|
242 |
m = folium.Map(location=[avg_lat, avg_lon], zoom_start=3)
|
243 |
else:
|
244 |
+
# Default world map if no valid coordinates
|
245 |
m = folium.Map(location=[20, 0], zoom_start=2)
|
246 |
|
247 |
+
# Add markers for each habitat location
|
248 |
for location in habitat_locations:
|
249 |
name = location.get("name", "Unknown")
|
250 |
lat = location.get("lat", 0)
|
251 |
lon = location.get("lon", 0)
|
252 |
description = location.get("description", "No description available")
|
253 |
|
254 |
+
# Skip invalid coordinates
|
255 |
if lat == 0 and lon == 0:
|
256 |
continue
|
257 |
|
258 |
+
# Add marker
|
259 |
folium.Marker(
|
260 |
location=[lat, lon],
|
261 |
popup=folium.Popup(f"<b>{name}</b><br>{description}", max_width=300),
|
262 |
tooltip=name
|
263 |
).add_to(m)
|
264 |
|
265 |
+
# Save map to HTML
|
266 |
map_html = m._repr_html_()
|
267 |
return map_html
|
268 |
|
269 |
def format_bird_info(raw_info, language="en"):
|
270 |
"""Improve the formatting of bird information"""
|
271 |
+
# Add proper line breaks between sections and ensure consistent heading levels
|
272 |
formatted = raw_info
|
273 |
+
|
274 |
+
# Get translation of warning text based on language
|
275 |
warning_text = "NOT TYPICALLY FOUND IN TANZANIA"
|
276 |
warning_translation = "HAPATIKANI SANA TANZANIA" if language == "sw" else warning_text
|
277 |
|
278 |
+
# Fix heading levels (make all main sections h3)
|
279 |
formatted = re.sub(r'#+\s+' + warning_text,
|
280 |
+
f'<div class="alert alert-warning"><strong>⚠️ {warning_translation}</strong></div>',
|
281 |
+
formatted)
|
282 |
|
283 |
+
# Replace markdown headings with HTML headings for better control
|
284 |
formatted = re.sub(r'#+\s+(.*)', r'<h3>\1</h3>', formatted)
|
285 |
+
|
286 |
+
# Add paragraph tags for better spacing
|
287 |
formatted = re.sub(r'\n\*\s+(.*)', r'<p>• \1</p>', formatted)
|
288 |
formatted = re.sub(r'\n([^<\n].*)', r'<p>\1</p>', formatted)
|
289 |
|
290 |
+
# Remove any duplicate paragraph tags
|
291 |
formatted = formatted.replace('<p><p>', '<p>')
|
292 |
formatted = formatted.replace('</p></p>', '</p>')
|
293 |
+
|
294 |
return formatted
|
295 |
|
296 |
def get_bird_info(bird_name, language="en"):
|
297 |
+
"""Get detailed information about a bird using Groq API with caching"""
|
298 |
clean_name = clean_bird_name(bird_name)
|
299 |
|
300 |
+
# Check cache first
|
301 |
+
cache_key = f"{clean_name}_{language}"
|
302 |
+
cached_info = load_from_cache("bird_info", cache_key)
|
303 |
+
if cached_info is not None:
|
304 |
+
return cached_info
|
305 |
+
|
306 |
+
# Adjust language for the prompt
|
307 |
+
lang_instruction = ""
|
308 |
+
if language == "sw":
|
309 |
+
lang_instruction = " Provide your response in Swahili language."
|
310 |
|
311 |
prompt = f"""
|
312 |
Provide detailed information about the {clean_name} bird, including:
|
|
|
323 |
|
324 |
try:
|
325 |
chat_completion = client.chat.completions.create(
|
326 |
+
messages=[
|
327 |
+
{
|
328 |
+
"role": "user",
|
329 |
+
"content": prompt,
|
330 |
+
}
|
331 |
+
],
|
332 |
model="llama-3.3-70b-versatile",
|
333 |
)
|
334 |
+
response = chat_completion.choices[0].message.content
|
335 |
+
# Cache the result
|
336 |
+
save_to_cache("bird_info", cache_key, response)
|
337 |
+
return response
|
338 |
except Exception as e:
|
339 |
error_msg = "Hitilafu katika kupata taarifa" if language == "sw" else "Error fetching information"
|
340 |
return f"{error_msg}: {str(e)}"
|
341 |
|
342 |
+
def create_message_html(message, icon="🔍", language="en"):
|
343 |
+
"""Create a styled message container for notifications"""
|
344 |
+
custom_css = """
|
345 |
+
<style>
|
346 |
+
.message-container {
|
347 |
+
font-family: Arial, sans-serif;
|
348 |
+
padding:
|
349 |
+
20px;
|
350 |
+
background-color: #f8f9fa;
|
351 |
+
border-radius: 8px;
|
352 |
+
text-align: center;
|
353 |
+
margin: 20px 0;
|
354 |
+
}
|
355 |
+
.message-icon {
|
356 |
+
font-size: 48px;
|
357 |
+
margin-bottom: 15px;
|
358 |
+
}
|
359 |
+
.message-text {
|
360 |
+
font-size: 18px;
|
361 |
+
color: #495057;
|
362 |
+
}
|
363 |
+
</style>
|
364 |
+
"""
|
365 |
+
|
366 |
+
html = f"""
|
367 |
+
{custom_css}
|
368 |
+
<div class="message-container">
|
369 |
+
<div class="message-icon">{icon}</div>
|
370 |
+
<div class="message-text">{message}</div>
|
371 |
+
</div>
|
372 |
+
"""
|
373 |
+
return html
|
374 |
+
|
375 |
def predict_and_get_info(img, language="en"):
|
376 |
"""Predict bird species and get detailed information"""
|
377 |
+
# Get translations
|
378 |
t = translations[language]
|
379 |
|
380 |
+
# Check if an image was provided
|
381 |
+
if img is None:
|
382 |
+
message = t['upload_prompt']
|
383 |
+
return None, create_message_html(message, "📷", language), "", ""
|
384 |
|
385 |
+
# Basic check if the image might contain a bird
|
386 |
+
if not is_likely_bird_image(img):
|
387 |
+
message = t['not_a_bird']
|
388 |
+
return None, create_message_html(message, "⚠️", language), "", ""
|
389 |
|
390 |
+
try:
|
391 |
+
# Process the image
|
392 |
+
img = PILImage.create(img)
|
393 |
+
|
394 |
+
# Get prediction
|
395 |
+
pred, pred_idx, probs = learn.predict(img)
|
396 |
+
|
397 |
+
# Get top 5 predictions (or all if less than 5)
|
398 |
+
num_classes = min(5, len(labels))
|
399 |
+
top_indices = probs.argsort(descending=True)[:num_classes]
|
400 |
+
top_probs = probs[top_indices]
|
401 |
+
top_labels = [labels[i] for i in top_indices]
|
|
|
|
|
402 |
|
403 |
+
# Format as dictionary with cleaned names for display
|
404 |
+
prediction_results = {clean_bird_name(top_labels[i]): float(top_probs[i]) for i in range(num_classes)}
|
405 |
+
|
406 |
+
# Get top prediction (original format for info retrieval)
|
407 |
+
top_bird = str(pred)
|
408 |
+
# Also keep a clean version for display
|
409 |
+
clean_top_bird = clean_bird_name(top_bird)
|
410 |
+
|
411 |
+
# Check if the model's confidence is low
|
412 |
+
if float(top_probs[0]) < 0.4:
|
413 |
+
low_confidence_warning = t['low_confidence']
|
414 |
+
else:
|
415 |
+
low_confidence_warning = ""
|
416 |
+
|
417 |
+
# Check if the top prediction is "Other" and has high confidence
|
418 |
+
if "other" in clean_top_bird.lower():
|
419 |
+
# Create a message informing the user that the bird wasn't recognized
|
420 |
+
other_message = t['other_message']
|
421 |
+
combined_info = create_message_html(other_message, "🔍", language)
|
422 |
+
return prediction_results, combined_info, clean_top_bird, ""
|
423 |
+
|
424 |
+
# Get habitat locations and create map
|
425 |
+
habitat_locations, is_in_tanzania = get_bird_habitat_map(top_bird)
|
426 |
+
habitat_map_html = create_habitat_map(habitat_locations)
|
427 |
+
|
428 |
+
# Get detailed information about the top predicted bird
|
429 |
+
bird_info = get_bird_info(top_bird, language)
|
430 |
+
formatted_info = format_bird_info(bird_info, language)
|
431 |
+
|
432 |
+
# Create combined info with map at the top and properly formatted information
|
433 |
custom_css = """
|
434 |
<style>
|
435 |
.bird-container {
|
436 |
font-family: Arial, sans-serif;
|
437 |
padding: 10px;
|
438 |
}
|
439 |
+
.map-container {
|
440 |
+
height: 400px;
|
441 |
+
width: 100%;
|
442 |
+
border: 1px solid #ddd;
|
443 |
+
border-radius: 8px;
|
444 |
+
overflow: hidden;
|
445 |
+
margin-bottom: 20px;
|
446 |
+
}
|
447 |
+
.info-container {
|
448 |
+
line-height: 1.6;
|
449 |
+
}
|
450 |
+
.info-container h3 {
|
451 |
+
margin-top: 20px;
|
452 |
+
margin-bottom: 10px;
|
453 |
+
color: #2c3e50;
|
454 |
+
border-bottom: 1px solid #eee;
|
455 |
+
padding-bottom: 5px;
|
456 |
+
}
|
457 |
+
.info-container p {
|
458 |
+
margin-bottom: 10px;
|
459 |
+
}
|
460 |
+
.alert {
|
461 |
padding: 10px;
|
462 |
margin-bottom: 15px;
|
463 |
border-radius: 4px;
|
464 |
}
|
465 |
+
.alert-warning {
|
466 |
+
background-color: #fcf8e3;
|
467 |
+
border: 1px solid #faebcc;
|
468 |
+
color: #8a6d3b;
|
469 |
+
}
|
470 |
+
.confidence-warning {
|
471 |
+
background-color: #fff3cd;
|
472 |
+
color: #856404;
|
473 |
+
padding: 8px;
|
474 |
+
border-radius: 4px;
|
475 |
+
margin-bottom: 15px;
|
476 |
+
font-weight: bold;
|
477 |
+
}
|
478 |
</style>
|
479 |
"""
|
480 |
|
481 |
+
# Add low confidence warning if needed
|
482 |
+
confidence_warning_html = f'<div class="confidence-warning">{low_confidence_warning}</div>' if low_confidence_warning else ''
|
483 |
+
|
484 |
combined_info = f"""
|
485 |
{custom_css}
|
486 |
<div class="bird-container">
|
487 |
+
{confidence_warning_html}
|
488 |
+
<h2>{t['habitat_map_title']} {clean_top_bird}</h2>
|
489 |
+
<div class="map-container">
|
490 |
+
{habitat_map_html}
|
491 |
+
</div>
|
492 |
+
|
493 |
+
<div class="info-container">
|
494 |
+
<h2>{t['detailed_info_title']}</h2>
|
495 |
+
{formatted_info}
|
496 |
</div>
|
497 |
</div>
|
498 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
500 |
+
return prediction_results, combined_info, clean_top_bird, ""
|
501 |
+
except Exception as e:
|
502 |
+
error_msg = "Hitilafu katika kuchakata picha" if language == "sw" else "Error processing image"
|
503 |
+
return None, create_message_html(f"{error_msg}: {str(e)}", "⚠️", language), "", ""
|
|
|
|
|
|
|
|
|
504 |
|
505 |
def follow_up_question(question, bird_name, language="en"):
|
506 |
"""Allow researchers to ask follow-up questions about the identified bird"""
|
|
|
509 |
if not question.strip() or not bird_name:
|
510 |
return "Please identify a bird first and ask a specific question about it." if language == "en" else "Tafadhali tambua ndege kwanza na uulize swali maalum kuhusu ndege huyo."
|
511 |
|
512 |
+
# Check cache first
|
513 |
+
cache_key = f"{bird_name}_{question}_{language}".replace(" ", "_")[:100] # Limit key length
|
514 |
+
cached_answer = load_from_cache("follow_up", cache_key)
|
515 |
+
if cached_answer is not None:
|
516 |
+
return cached_answer
|
517 |
+
|
518 |
+
# Adjust language for the prompt
|
519 |
+
lang_instruction = ""
|
520 |
+
if language == "sw":
|
521 |
+
lang_instruction = " Provide your response in Swahili language."
|
522 |
|
523 |
prompt = f"""
|
524 |
The researcher is asking about the {bird_name} bird: "{question}"
|
|
|
535 |
|
536 |
try:
|
537 |
chat_completion = client.chat.completions.create(
|
538 |
+
messages=[
|
539 |
+
{
|
540 |
+
"role": "user",
|
541 |
+
"content": prompt,
|
542 |
+
}
|
543 |
+
],
|
544 |
model="llama-3.3-70b-versatile",
|
545 |
)
|
546 |
+
response = chat_completion.choices[0].message.content
|
547 |
+
# Cache the result
|
548 |
+
save_to_cache("follow_up", cache_key, response)
|
549 |
+
return response
|
550 |
except Exception as e:
|
551 |
error_msg = "Hitilafu katika kupata taarifa" if language == "sw" else "Error fetching information"
|
552 |
return f"{error_msg}: {str(e)}"
|
553 |
|
554 |
# Create the Gradio interface
|
555 |
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
556 |
+
# Current language and bird state
|
557 |
current_lang = gr.State("en")
|
558 |
current_bird = gr.State("")
|
559 |
|
560 |
+
# Header with language switcher
|
561 |
with gr.Row():
|
562 |
with gr.Column(scale=3):
|
563 |
title_md = gr.Markdown(f"# {translations['en']['app_title']}")
|
|
|
568 |
value="English"
|
569 |
)
|
570 |
|
571 |
+
# App description
|
572 |
description_md = gr.Markdown(f"{translations['en']['app_description']}")
|
573 |
|
574 |
+
# Main identification section
|
575 |
with gr.Row():
|
576 |
with gr.Column(scale=1):
|
577 |
input_image = gr.Image(type="pil", label=translations['en']['upload_label'])
|
|
|
581 |
prediction_output = gr.Label(label=translations['en']['predictions_label'], num_top_classes=5)
|
582 |
bird_info_output = gr.HTML(label=translations['en']['bird_info_label'])
|
583 |
|
584 |
+
# Clear divider
|
585 |
gr.Markdown("---")
|
586 |
|
587 |
+
# Follow-up question section with improved UI
|
588 |
questions_header = gr.Markdown(f"## {translations['en']['research_questions']}")
|
589 |
|
590 |
conversation_history = gr.Markdown("")
|
|
|
600 |
follow_up_btn = gr.Button(translations['en']['submit_question'], variant="primary")
|
601 |
clear_btn = gr.Button(translations['en']['clear_conversation'])
|
602 |
|
603 |
+
# Functions for event handlers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
def update_conversation(question, bird_name, history, lang):
|
605 |
t = translations[lang]
|
606 |
|
|
|
609 |
|
610 |
answer = follow_up_question(question, bird_name, lang)
|
611 |
|
612 |
+
# Format the conversation with clear separation
|
613 |
new_exchange = f"""
|
614 |
### {t['question_title']}
|
615 |
{question}
|
|
|
624 |
return ""
|
625 |
|
626 |
def update_language(choice):
|
627 |
+
# Convert selection to language code
|
628 |
lang = "sw" if choice == "Kiswahili" else "en"
|
629 |
t = translations[lang]
|
630 |
|
631 |
+
# Return updated UI components based on selected language
|
632 |
return (
|
633 |
lang,
|
634 |
f"# {t['app_title']}",
|
|
|
644 |
t['clear_conversation']
|
645 |
)
|
646 |
|
647 |
+
# Set up event handlers
|
648 |
language_selector.change(
|
649 |
update_language,
|
650 |
inputs=[language_selector],
|
|
|
664 |
]
|
665 |
)
|
666 |
|
667 |
+
# Add loading state for better UX
|
668 |
submit_btn.click(
|
669 |
+
lambda x, y: (None, create_message_html(translations[y]['loading'], "⏳", y), "", ""),
|
670 |
+
inputs=[input_image, current_lang],
|
671 |
+
outputs=[prediction_output, bird_info_output, current_bird, conversation_history]
|
672 |
+
).then(
|
673 |
+
predict_and_get_info,
|
674 |
inputs=[input_image, current_lang],
|
675 |
outputs=[prediction_output, bird_info_output, current_bird, conversation_history]
|
676 |
)
|