Spaces:
Sleeping
Sleeping
Commit
·
249c00e
1
Parent(s):
99ba249
Initial version
Browse files- app.py +143 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import os
|
5 |
+
from google import genai
|
6 |
+
from google.genai import types
|
7 |
+
from google.genai import errors
|
8 |
+
from bioclip import TreeOfLifeClassifier, Rank
|
9 |
+
|
10 |
+
|
11 |
+
# Number of attempts made before a Gemini server error is surfaced.
PROMPT_RETRIES = 2
# Backward-compatible alias preserving the original (misspelled) name,
# which existing callers may still reference.
PROMPT_RETRYIES = PROMPT_RETRIES

# Default instruction sent to Gemini together with the uploaded image.
DEFAULT_PROMPT = """
Return bounding boxes and a description for each species in this image.
Ensure you only return valid JSON.
""".strip()
|
16 |
+
|
17 |
+
# Initialize classifier outside of functions so the (expensive) BioCLIP
# model is loaded once at import time and shared by every request.
classifier = TreeOfLifeClassifier()
|
19 |
+
|
20 |
+
|
21 |
+
def crop_image(image, gemini_bounding_box):
    """
    Crop the image based on a Gemini-style bounding box.

    :param image: PIL Image object
    :param gemini_bounding_box: Sequence of (y_min, x_min, y_max, x_max),
        normalized to the 0-1000 range used by Gemini box annotations.
    :return: Cropped PIL Image
    """
    width, height = image.size
    y_min, x_min, y_max, x_max = gemini_bounding_box

    # Convert normalized 0-1000 coordinates to pixel values, clamping to the
    # image bounds so out-of-range model output cannot yield padded crops.
    left = max(0, min(int(x_min / 1000 * width), width))
    upper = max(0, min(int(y_min / 1000 * height), height))
    right = max(0, min(int(x_max / 1000 * width), width))
    lower = max(0, min(int(y_max / 1000 * height), height))

    # Crop and return the image
    return image.crop((left, upper, right, lower))
|
40 |
+
|
41 |
+
|
42 |
+
def predict_species(img):
    """Classify a cropped image with BioCLIP and return the top species hit."""
    results = classifier.predict([img], Rank.SPECIES, k=1)
    top_hit = results[0]
    return top_hit
|
45 |
+
|
46 |
+
|
47 |
+
def make_crops(image, predictions_json_txt):
    """
    Crop an image once per bounding box found in a Gemini JSON response.

    :param image: PIL Image object
    :param predictions_json_txt: JSON text - a list of prediction dicts, each
        optionally carrying a "box_2d" (y_min, x_min, y_max, x_max) box.
    :return: List of cropped PIL Images; empty list on malformed JSON.
    """
    cropped_images = []
    try:
        predictions = json.loads(predictions_json_txt)
    except json.JSONDecodeError as e:
        # The model occasionally emits invalid JSON; treat as "no detections".
        print(str(e))
        return []  # Return empty list if JSON parsing fails

    for prediction in predictions:
        if "box_2d" in prediction:
            # Crop the image using the bounding box; skip malformed boxes
            # instead of failing the whole batch.
            try:
                cropped_images.append(crop_image(image, prediction["box_2d"]))
            except Exception as e:
                print(f"Error cropping image: {e}")

    return cropped_images
|
74 |
+
|
75 |
+
|
76 |
+
def generate_content_str(api_key, prompt, pil_image, tries=PROMPT_RETRYIES):
    """
    Ask Gemini for species bounding boxes, then label each crop with BioCLIP.

    :param api_key: Gemini API key used to build the client.
    :param prompt: Text prompt sent alongside the image.
    :param pil_image: PIL Image to analyze.
    :param tries: Number of attempts before a server error is re-raised.
    :return: Tuple of (raw JSON text from Gemini,
             list of (cropped image, BioCLIP label string) pairs).
    :raises errors.ServerError: once all retry attempts are exhausted.
    """
    # Initialize the client with the provided API key
    client = genai.Client(api_key=api_key)
    generate_content_config = types.GenerateContentConfig(
        response_mime_type="application/json",
    )

    while True:
        try:
            response = client.models.generate_content(
                model="gemini-2.5-pro-exp-03-25",
                contents=[prompt, pil_image],
                config=generate_content_config,
            )
            print("Result", response.text)
            crop_images = make_crops(
                image=pil_image, predictions_json_txt=response.text
            )
            # Pair each crop with a "common name - species - score" label
            # produced by the BioCLIP classifier, for the Gradio gallery.
            crop_images_with_labels = []
            for img in crop_images:
                prediction = predict_species(img)
                label = f"{prediction['common_name']} - {prediction['species']} - {round(prediction['score'],3)}"
                crop_images_with_labels.append((img, label))
            return response.text, crop_images_with_labels
        except errors.ServerError as e:
            tries -= 1
            # Guard with <= so a non-positive initial `tries` cannot loop forever.
            if tries <= 0:
                raise
            print(f"Retrying... {e}")
            time.sleep(5)
|
107 |
+
|
108 |
+
|
109 |
+
# ---- Gradio interface ------------------------------------------------------
with gr.Blocks(title="Gemini 2.5 Pro Explore") as demo:
    gr.Markdown("# Image Analysis with Gemini 2.5 Pro")

    with gr.Row():
        # Left column: user-supplied inputs.
        with gr.Column():
            gr.Markdown("## Upload an image and enter a prompt to get predictions")
            api_key_box = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key here...",
                type="password",
            )
            image_box = gr.Image(label="Upload an image", type="pil")
            prompt_box = gr.TextArea(
                label="Enter your prompt",
                placeholder="Describe what you want to analyze...",
                value=DEFAULT_PROMPT,
            )
            analyze_button = gr.Button("Analyze")

        # Right column: raw Gemini output plus the labeled crops.
        with gr.Column():
            gr.Markdown("## Gemini Results")
            predictions_json = gr.JSON(label="Predictions")
            gr.Markdown("## Cropped Images with BioCLIP Predictions")
            crop_gallery = gr.Gallery(label="Images", show_label=True)

    # Run the Gemini + BioCLIP pipeline when the button is pressed.
    analyze_button.click(
        fn=generate_content_str,
        inputs=[api_key_box, prompt_box, image_box],
        outputs=[predictions_json, crop_gallery],
    )

# Launch the app only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
pybioclip
|
2 |
+
google-genai
|