lyimo commited on
Commit
c30908d
·
verified ·
1 Parent(s): 1b8384b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +402 -7
app.py CHANGED
@@ -1,13 +1,408 @@
1
- import timm
 
 
 
2
  from fastai.vision.all import *
 
 
3
 
 
4
  learn = load_learner('export.pkl')
5
-
6
  labels = learn.dls.vocab
7
- def predict(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  img = PILImage.create(img)
9
- pred,pred_idx,probs = learn.predict(img)
10
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- import gradio as gr
13
- gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=3)).launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import re
4
+ import folium
5
  from fastai.vision.all import *
6
+ from groq import Groq
7
+ from PIL import Image
8
 
9
+ # Load the trained model
10
  learn = load_learner('export.pkl')
 
11
  labels = learn.dls.vocab
12
+
13
+ # Initialize Groq client
14
+ client = Groq(
15
+ api_key=os.environ.get("GROQ_API_KEY"),
16
+ )
17
+
18
+ def clean_bird_name(name):
19
+ """Clean bird name by removing numbers and special characters, and fix formatting"""
20
+ # Remove numbers and dots at the beginning
21
+ cleaned = re.sub(r'^\d+\.', '', name)
22
+ # Replace underscores with spaces
23
+ cleaned = cleaned.replace('_', ' ')
24
+ # Remove any remaining special characters
25
+ cleaned = re.sub(r'[^\w\s]', '', cleaned)
26
+ # Fix spacing
27
+ cleaned = ' '.join(cleaned.split())
28
+ return cleaned
29
+
30
+ def get_bird_habitat_map(bird_name, check_tanzania=True):
31
+ """Get habitat map locations for the bird using Groq API"""
32
+ clean_name = clean_bird_name(bird_name)
33
+
34
+ # First check if the bird is endemic to Tanzania
35
+ if check_tanzania:
36
+ tanzania_check_prompt = f"""
37
+ Is the {clean_name} bird native to or commonly found in Tanzania?
38
+ Answer with ONLY "yes" or "no".
39
+ """
40
+
41
+ try:
42
+ tanzania_check = client.chat.completions.create(
43
+ messages=[{"role": "user", "content": tanzania_check_prompt}],
44
+ model="llama-3.3-70b-versatile",
45
+ )
46
+ is_in_tanzania = "yes" in tanzania_check.choices[0].message.content.lower()
47
+ except:
48
+ # Default to showing Tanzania if we can't determine
49
+ is_in_tanzania = True
50
+ else:
51
+ is_in_tanzania = True
52
+
53
+ # Now get the habitat locations
54
+ prompt = f"""
55
+ Provide a JSON array of the main habitat locations for the {clean_name} bird in the world.
56
+ Return ONLY a JSON array with 3-5 entries, each containing:
57
+ 1. "name": Location name
58
+ 2. "lat": Latitude (numeric value)
59
+ 3. "lon": Longitude (numeric value)
60
+ 4. "description": Brief description of why this is a key habitat (2-3 sentences)
61
+
62
+ Example format:
63
+ [
64
+ {{"name": "Example Location", "lat": 12.34, "lon": 56.78, "description": "Brief description"}},
65
+ ...
66
+ ]
67
+
68
+ {'' if is_in_tanzania else 'DO NOT include any locations in Tanzania as this bird is not native to or commonly found there.'}
69
+ """
70
+
71
+ try:
72
+ chat_completion = client.chat.completions.create(
73
+ messages=[
74
+ {
75
+ "role": "user",
76
+ "content": prompt,
77
+ }
78
+ ],
79
+ model="llama-3.3-70b-versatile",
80
+ )
81
+ response = chat_completion.choices[0].message.content
82
+
83
+ # Extract JSON from response (in case there's additional text)
84
+ import json
85
+ import re
86
+
87
+ # Find JSON pattern in response
88
+ json_match = re.search(r'\[.*\]', response, re.DOTALL)
89
+ if json_match:
90
+ locations = json.loads(json_match.group())
91
+ else:
92
+ # Fallback if JSON extraction fails
93
+ locations = [
94
+ {"name": "Primary habitat region", "lat": 0, "lon": 0,
95
+ "description": "Could not retrieve specific habitat information for this bird."}
96
+ ]
97
+
98
+ return locations, is_in_tanzania
99
+
100
+ except Exception as e:
101
+ return [{"name": "Error retrieving data", "lat": 0, "lon": 0,
102
+ "description": "Please try again or check your connection."}], False
103
+
104
+ def create_habitat_map(habitat_locations):
105
+ """Create a folium map with the habitat locations"""
106
+ # Find center point based on valid coordinates
107
+ valid_coords = [(loc.get("lat", 0), loc.get("lon", 0))
108
+ for loc in habitat_locations
109
+ if loc.get("lat", 0) != 0 or loc.get("lon", 0) != 0]
110
+
111
+ if valid_coords:
112
+ # Calculate the average of the coordinates
113
+ avg_lat = sum(lat for lat, _ in valid_coords) / len(valid_coords)
114
+ avg_lon = sum(lon for _, lon in valid_coords) / len(valid_coords)
115
+ # Create map centered on the average coordinates
116
+ m = folium.Map(location=[avg_lat, avg_lon], zoom_start=3)
117
+ else:
118
+ # Default world map if no valid coordinates
119
+ m = folium.Map(location=[20, 0], zoom_start=2)
120
+
121
+ # Add markers for each habitat location
122
+ for location in habitat_locations:
123
+ name = location.get("name", "Unknown")
124
+ lat = location.get("lat", 0)
125
+ lon = location.get("lon", 0)
126
+ description = location.get("description", "No description available")
127
+
128
+ # Skip invalid coordinates
129
+ if lat == 0 and lon == 0:
130
+ continue
131
+
132
+ # Add marker
133
+ folium.Marker(
134
+ location=[lat, lon],
135
+ popup=folium.Popup(f"<b>{name}</b><br>{description}", max_width=300),
136
+ tooltip=name
137
+ ).add_to(m)
138
+
139
+ # Save map to HTML
140
+ map_html = m._repr_html_()
141
+ return map_html
142
+
143
+ def format_bird_info(raw_info):
144
+ """Improve the formatting of bird information"""
145
+ # Add proper line breaks between sections and ensure consistent heading levels
146
+ formatted = raw_info
147
+
148
+ # Fix heading levels (make all main sections h3)
149
+ formatted = re.sub(r'#+\s+NOT TYPICALLY FOUND IN TANZANIA',
150
+ '<div class="alert alert-warning"><strong>⚠️ NOT TYPICALLY FOUND IN TANZANIA</strong></div>',
151
+ formatted)
152
+
153
+ # Replace markdown headings with HTML headings for better control
154
+ formatted = re.sub(r'#+\s+(.*)', r'<h3>\1</h3>', formatted)
155
+
156
+ # Add paragraph tags for better spacing
157
+ formatted = re.sub(r'\n\*\s+(.*)', r'<p>• \1</p>', formatted)
158
+ formatted = re.sub(r'\n([^<\n].*)', r'<p>\1</p>', formatted)
159
+
160
+ # Remove any duplicate paragraph tags
161
+ formatted = formatted.replace('<p><p>', '<p>')
162
+ formatted = formatted.replace('</p></p>', '</p>')
163
+
164
+ return formatted
165
+
166
+ def get_bird_info(bird_name):
167
+ """Get detailed information about a bird using Groq API"""
168
+ clean_name = clean_bird_name(bird_name)
169
+
170
+ prompt = f"""
171
+ Provide detailed information about the {clean_name} bird, including:
172
+ 1. Physical characteristics and appearance
173
+ 2. Habitat and distribution
174
+ 3. Diet and behavior
175
+ 4. Migration patterns (emphasize if this pattern has changed in recent years due to climate change)
176
+ 5. Conservation status
177
+
178
+ If this bird is not commonly found in Tanzania, explicitly flag that this bird is "NOT TYPICALLY FOUND IN TANZANIA" at the beginning of your response and explain why its presence might be unusual.
179
+
180
+ Format your response in markdown for better readability.
181
+ """
182
+
183
+ try:
184
+ chat_completion = client.chat.completions.create(
185
+ messages=[
186
+ {
187
+ "role": "user",
188
+ "content": prompt,
189
+ }
190
+ ],
191
+ model="llama-3.3-70b-versatile",
192
+ )
193
+ return chat_completion.choices[0].message.content
194
+ except Exception as e:
195
+ return f"Error fetching information: {str(e)}"
196
+
197
+ def predict_and_get_info(img):
198
+ """Predict bird species and get detailed information"""
199
+ # Process the image
200
  img = PILImage.create(img)
201
+
202
+ # Get prediction
203
+ pred, pred_idx, probs = learn.predict(img)
204
+
205
+ # Get top 5 predictions (or all if less than 5)
206
+ num_classes = min(5, len(labels))
207
+ top_indices = probs.argsort(descending=True)[:num_classes]
208
+ top_probs = probs[top_indices]
209
+ top_labels = [labels[i] for i in top_indices]
210
+
211
+ # Format as dictionary with cleaned names for display
212
+ prediction_results = {clean_bird_name(top_labels[i]): float(top_probs[i]) for i in range(num_classes)}
213
+
214
+ # Get top prediction (original format for info retrieval)
215
+ top_bird = str(pred)
216
+ # Also keep a clean version for display
217
+ clean_top_bird = clean_bird_name(top_bird)
218
+
219
+ # Get habitat locations and create map
220
+ habitat_locations, is_in_tanzania = get_bird_habitat_map(top_bird)
221
+ habitat_map_html = create_habitat_map(habitat_locations)
222
+
223
+ # Get detailed information about the top predicted bird
224
+ bird_info = get_bird_info(top_bird)
225
+ formatted_info = format_bird_info(bird_info)
226
+
227
+ # Create combined info with map at the top and properly formatted information
228
+ custom_css = """
229
+ <style>
230
+ .bird-container {
231
+ font-family: Arial, sans-serif;
232
+ padding: 10px;
233
+ }
234
+ .map-container {
235
+ height: 400px;
236
+ width: 100%;
237
+ border: 1px solid #ddd;
238
+ border-radius: 8px;
239
+ overflow: hidden;
240
+ margin-bottom: 20px;
241
+ }
242
+ .info-container {
243
+ line-height: 1.6;
244
+ }
245
+ .info-container h3 {
246
+ margin-top: 20px;
247
+ margin-bottom: 10px;
248
+ color: #2c3e50;
249
+ border-bottom: 1px solid #eee;
250
+ padding-bottom: 5px;
251
+ }
252
+ .info-container p {
253
+ margin-bottom: 10px;
254
+ }
255
+ .alert {
256
+ padding: 10px;
257
+ margin-bottom: 15px;
258
+ border-radius: 4px;
259
+ }
260
+ .alert-warning {
261
+ background-color: #fcf8e3;
262
+ border: 1px solid #faebcc;
263
+ color: #8a6d3b;
264
+ }
265
+ </style>
266
+ """
267
+
268
+ combined_info = f"""
269
+ {custom_css}
270
+ <div class="bird-container">
271
+ <h2>Natural Habitat Map for {clean_top_bird}</h2>
272
+ <div class="map-container">
273
+ {habitat_map_html}
274
+ </div>
275
+
276
+ <div class="info-container">
277
+ <h2>Detailed Information</h2>
278
+ {formatted_info}
279
+ </div>
280
+ </div>
281
+ """
282
+
283
+ return prediction_results, combined_info, clean_top_bird
284
 
285
+ def follow_up_question(question, bird_name):
286
+ """Allow researchers to ask follow-up questions about the identified bird"""
287
+ if not question.strip() or not bird_name:
288
+ return "Please identify a bird first and ask a specific question about it."
289
+
290
+ prompt = f"""
291
+ The researcher is asking about the {bird_name} bird: "{question}"
292
+
293
+ Provide a detailed, scientific answer focusing on accurate ornithological information.
294
+ If the question relates to Tanzania or climate change impacts, emphasize those aspects in your response.
295
+
296
+ IMPORTANT: Do not repeat basic introductory information about the bird that would have already been provided in a general description.
297
+ Do not start your answer with phrases like "Introduction to the {bird_name}" or similar repetitive headers.
298
+ Directly answer the specific question asked.
299
+
300
+ Format your response in markdown for better readability.
301
+ """
302
+
303
+ try:
304
+ chat_completion = client.chat.completions.create(
305
+ messages=[
306
+ {
307
+ "role": "user",
308
+ "content": prompt,
309
+ }
310
+ ],
311
+ model="llama-3.3-70b-versatile",
312
+ )
313
+ return chat_completion.choices[0].message.content
314
+ except Exception as e:
315
+ return f"Error fetching information: {str(e)}"
316
+
317
+ # Create the Gradio interface
318
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
319
+ gr.Markdown("# Bird Species Identification for Researchers")
320
+ gr.Markdown("Upload an image to identify bird species and get detailed information relevant to research in Tanzania and climate change studies.")
321
+
322
+ # Store the current bird for context
323
+ current_bird = gr.State("")
324
+
325
+ # Main identification section
326
+ with gr.Row():
327
+ with gr.Column(scale=1):
328
+ input_image = gr.Image(type="pil", label="Upload Bird Image")
329
+ submit_btn = gr.Button("Identify Bird", variant="primary")
330
+
331
+ with gr.Column(scale=2):
332
+ prediction_output = gr.Label(label="Top 5 Predictions", num_top_classes=5)
333
+ bird_info_output = gr.HTML(label="Bird Information")
334
+
335
+ # Clear divider
336
+ gr.Markdown("---")
337
+
338
+ # Follow-up question section with improved UI
339
+ gr.Markdown("## Research Questions")
340
+
341
+ conversation_history = gr.Markdown("")
342
+
343
+ with gr.Row():
344
+ follow_up_input = gr.Textbox(
345
+ label="Ask a question about this bird",
346
+ placeholder="Example: How has climate change affected this bird's migration pattern?",
347
+ lines=2
348
+ )
349
+
350
+ with gr.Row():
351
+ follow_up_btn = gr.Button("Submit Question", variant="primary")
352
+ clear_btn = gr.Button("Clear Conversation")
353
+
354
+ # Set up event handlers
355
+ def process_image(img):
356
+ if img is None:
357
+ return None, "Please upload an image", "", ""
358
+
359
+ try:
360
+ pred_results, info, clean_bird_name = predict_and_get_info(img)
361
+ return pred_results, info, clean_bird_name, ""
362
+ except Exception as e:
363
+ return None, f"Error processing image: {str(e)}", "", ""
364
+
365
+ def update_conversation(question, bird_name, history):
366
+ if not question.strip():
367
+ return history
368
+
369
+ answer = follow_up_question(question, bird_name)
370
+
371
+ # Format the conversation with clear separation
372
+ new_exchange = f"""
373
+ ### Question:
374
+ {question}
375
+
376
+ ### Answer:
377
+ {answer}
378
+
379
+ ---
380
+ """
381
+ updated_history = new_exchange + history
382
+ return updated_history
383
+
384
+ def clear_conversation_history():
385
+ return ""
386
+
387
+ submit_btn.click(
388
+ process_image,
389
+ inputs=[input_image],
390
+ outputs=[prediction_output, bird_info_output, current_bird, conversation_history]
391
+ )
392
+
393
+ follow_up_btn.click(
394
+ update_conversation,
395
+ inputs=[follow_up_input, current_bird, conversation_history],
396
+ outputs=[conversation_history]
397
+ ).then(
398
+ lambda: "",
399
+ outputs=follow_up_input
400
+ )
401
+
402
+ clear_btn.click(
403
+ clear_conversation_history,
404
+ outputs=[conversation_history]
405
+ )
406
+
407
+ # Launch the app
408
+ app.launch(share=True)