Spaces:
Sleeping
Sleeping
import gradio as gr | |
import json | |
import time | |
import os | |
from google import genai | |
from google.genai import types | |
from google.genai import errors | |
from bioclip import TreeOfLifeClassifier, Rank | |
# Default number of attempts against the Gemini API before a server error is re-raised.
PROMPT_RETRYIES = 2

# Prompt pre-filled in the UI; it asks for bounding boxes and JSON-only output.
DEFAULT_PROMPT = "\n".join(
    (
        "Return bounding boxes and a description for each species in this image.",
        "Ensure you only return valid JSON.",
    )
)
# Build the classifier once at module import so every prediction call reuses the same instance
classifier = TreeOfLifeClassifier()
def crop_image(image, gemini_bounding_box):
    """
    Cut the region described by a Gemini bounding box out of an image.

    :param image: PIL Image object
    :param gemini_bounding_box: Sequence of (y_min, x_min, y_max, x_max) in range 0-1000
    :return: Cropped PIL Image
    """
    img_width, img_height = image.size
    top, start, bottom, end = gemini_bounding_box

    def to_pixels(coordinate, extent):
        # Divide by 1000 first, then scale to the image extent; int() truncates.
        return int(coordinate / 1000 * extent)

    # PIL expects the box as (left, upper, right, lower), i.e. x before y.
    pixel_box = (
        to_pixels(start, img_width),
        to_pixels(top, img_height),
        to_pixels(end, img_width),
        to_pixels(bottom, img_height),
    )
    return image.crop(pixel_box)
def predict_species(img):
    """
    Classify a single image at species rank using the module-level classifier.

    :param img: Image to classify; it is passed to the classifier as a one-element list
    :return: First entry of the prediction list returned by the classifier
    """
    single_image_batch = [img]
    results = classifier.predict(single_image_batch, Rank.SPECIES, k=1)
    return results[0]
def make_crops(image, predictions_json_txt):
    """
    Crop one sub-image per bounding box found in a JSON prediction string.

    :param image: PIL Image object
    :param predictions_json_txt: str of JSON holding a list of prediction
        dictionaries. A single dictionary is treated as a one-element list.
        Entries that are not dictionaries or have no "box_2d" key are skipped.
    :return: List of cropped images; empty if the text is missing, is not valid
        JSON, or does not decode to a list/dictionary
    """
    try:
        predictions = json.loads(predictions_json_txt)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a missing (None) text, which json.loads rejects.
        print(str(e))
        return []  # Return empty list if JSON parsing fails

    if isinstance(predictions, dict):
        # A lone prediction object is handled like a list with one entry.
        predictions = [predictions]
    elif not isinstance(predictions, list):
        print(f"Unexpected predictions payload: {type(predictions).__name__}")
        return []

    cropped_images = []
    for prediction in predictions:
        # Guard against scalars/strings/null inside the list before key lookup.
        if not isinstance(prediction, dict) or "box_2d" not in prediction:
            continue
        try:
            cropped_images.append(crop_image(image, prediction["box_2d"]))
        except Exception as e:
            # Best effort: one malformed box must not discard the other crops.
            print(f"Error cropping image: {e}")
    return cropped_images
def generate_content_str(api_key, prompt, pil_image, tries=PROMPT_RETRYIES):
    """
    Ask Gemini for bounding boxes, crop each box, and label every crop.

    :param api_key: Gemini API key used to create the client
    :param prompt: Prompt text sent together with the image
    :param pil_image: PIL Image object to analyze
    :param tries: Total number of attempts when the API raises a server error;
        values below 1 are treated as 1
    :return: Tuple of (raw response text, list of (cropped image, label) pairs).
        The text is None and the list empty when the response carries no text.
    :raises gr.Error: if no image was provided
    :raises errors.ServerError: if every attempt fails with a server error
    """
    if pil_image is None:
        # Surface a readable message in the UI instead of an SDK failure.
        raise gr.Error("Please upload an image before clicking Analyze.")

    # Initialize the client with the provided API key
    client = genai.Client(api_key=api_key)
    generate_content_config = types.GenerateContentConfig(
        response_mime_type="application/json",
    )

    # Always make at least one attempt; a non-positive count must not loop forever.
    attempts_left = max(tries, 1)
    while True:
        # Only the remote call is retried; cropping/labelling happens afterwards.
        try:
            response = client.models.generate_content(
                model="gemini-2.5-pro-exp-03-25",
                contents=[prompt, pil_image],
                config=generate_content_config,
            )
            break
        except errors.ServerError as e:
            attempts_left -= 1
            if attempts_left <= 0:
                raise
            print(f"Retrying... {e}")
            time.sleep(5)

    response_text = response.text
    print("Result", response_text)
    if response_text is None:
        # Nothing to parse or crop when the response has no text part.
        return None, []

    crop_images_with_labels = []
    for img in make_crops(image=pil_image, predictions_json_txt=response_text):
        prediction = predict_species(img)
        label = f"{prediction['common_name']} - {prediction['species']} - {round(prediction['score'],3)}"
        crop_images_with_labels.append((img, label))
    return response_text, crop_images_with_labels
# Gradio UI: inputs on the left, Gemini JSON and labelled crops on the right
with gr.Blocks(title="Gemini 2.5 Pro Explore") as demo:
    gr.Markdown("# Image Analysis with Gemini 2.5 Pro + BioCLIP")

    with gr.Row():
        # Input column: API key, image upload, prompt and the trigger button
        with gr.Column():
            gr.Markdown("## Upload an image and enter a prompt to get predictions")
            api_key_input = gr.Textbox(
                type="password",
                label="Gemini API Key",
                placeholder="Enter your Gemini API key here...",
            )
            image_input = gr.Image(type="pil", label="Upload an image")
            gr.Markdown("The prompt below must request bounding boxes.")
            prompt_input = gr.TextArea(
                value=DEFAULT_PROMPT,
                label="Enter your prompt",
                placeholder="Describe what you want to analyze...",
            )
            submit_btn = gr.Button("Analyze")

        # Output column: raw Gemini predictions and the cropped-image gallery
        with gr.Column():
            gr.Markdown("## Gemini Results")
            output = gr.JSON(label="Predictions")
            gr.Markdown("## Cropped Images with BioCLIP Predictions")
            image_gallery = gr.Gallery(label="Images", show_label=True)

    # Input order matches the (api_key, prompt, pil_image) parameters of the handler
    submit_btn.click(
        fn=generate_content_str,
        inputs=[api_key_input, prompt_input, image_input],
        outputs=[output, image_gallery],
    )

# Start the app only when the file is run as a script
if __name__ == "__main__":
    demo.launch()