# gemini-bioclip / app.py
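"""Gradio app that sends an uploaded image to Gemini for per-species bounding boxes,
crops each box, and labels the crops with BioCLIP Tree of Life species predictions."""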
import json
import time

import gradio as gr
from google import genai
from google.genai import types
from google.genai import errors
from bioclip import TreeOfLifeClassifier, Rank

PROMPT_RETRIES = 2
DEFAULT_PROMPT = """
Return bounding boxes and a description for each species in this image.
Ensure you only return valid JSON.
""".strip()

# Initialize the classifier once at import time so the Tree of Life model is loaded only on startup
classifier = TreeOfLifeClassifier()


def crop_image(image, gemini_bounding_box):
    """
    Crop the image based on Gemini bounding box coordinates.

    :param image: PIL Image object
    :param gemini_bounding_box: Tuple of (y_min, x_min, y_max, x_max) in the range 0-1000
    :return: Cropped PIL Image
    """
width, height = image.size
y_min, x_min, y_max, x_max = gemini_bounding_box
# Convert normalized coordinates to pixel values
left = int(x_min / 1000 * width)
upper = int(y_min / 1000 * height)
right = int(x_max / 1000 * width)
lower = int(y_max / 1000 * height)
# Crop and return the image
return image.crop((left, upper, right, lower))


def predict_species(img):
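    """Return BioCLIP's top Tree of Life species prediction for a cropped image."""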
predictions = classifier.predict([img], Rank.SPECIES, k=1)
return predictions[0]


def make_crops(image, predictions_json_txt):
    """
    Crop the image for each bounding box in Gemini's JSON response.

    :param image: PIL Image object
    :param predictions_json_txt: JSON string containing a list of prediction dictionaries with bounding boxes
    :return: List of cropped images
    """
cropped_images = []
    try:
        predictions = json.loads(predictions_json_txt)
    except json.JSONDecodeError as e:
        print(str(e))
        return []  # Return an empty list if JSON parsing fails
for prediction in predictions:
if "box_2d" in prediction:
gemini_bounding_box = prediction["box_2d"]
# Crop the image using the bounding box
try:
cropped_image = crop_image(image, gemini_bounding_box)
cropped_images.append(cropped_image)
except Exception as e:
print(f"Error cropping image: {e}")
return cropped_images


def generate_content_str(api_key, prompt, pil_image, tries=PROMPT_RETRIES):
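    """Ask Gemini for bounding boxes, crop them out of the image, and label each crop with BioCLIP.

    Returns the raw Gemini JSON text and a list of (cropped image, label) tuples for the gallery.
    """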
# Initialize the client with the provided API key
client = genai.Client(api_key=api_key)
generate_content_config = types.GenerateContentConfig(
response_mime_type="application/json",
)
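    # Retry the Gemini request on transient server errors, giving up after the configured number of tries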
while True:
try:
response = client.models.generate_content(
model="gemini-2.5-pro-exp-03-25",
contents=[prompt, pil_image],
config=generate_content_config,
)
print("Result", response.text)
crop_images = make_crops(
image=pil_image, predictions_json_txt=response.text
)
            # Label each crop with BioCLIP's top species prediction for the Gradio gallery
            crop_images_with_labels = []
            for img in crop_images:
                prediction = predict_species(img)
                label = f"{prediction['common_name']} - {prediction['species']} - {round(prediction['score'], 3)}"
                crop_images_with_labels.append((img, label))
return response.text, crop_images_with_labels
        except errors.ServerError as e:
            tries -= 1
            if tries == 0:
                raise
            print(f"Retrying after server error: {e}")
            time.sleep(5)


# Define the Gradio interface
with gr.Blocks(title="Gemini 2.5 Pro Explore") as demo:
gr.Markdown("# Image Analysis with Gemini 2.5 Pro + BioCLIP")
with gr.Row():
with gr.Column():
gr.Markdown("## Upload an image and enter a prompt to get predictions")
api_key_input = gr.Textbox(
label="Gemini API Key",
placeholder="Enter your Gemini API key here...",
type="password",
)
image_input = gr.Image(label="Upload an image", type="pil")
gr.Markdown("The prompt below must request bounding boxes.")
prompt_input = gr.TextArea(
label="Enter your prompt",
placeholder="Describe what you want to analyze...",
value=DEFAULT_PROMPT,
)
submit_btn = gr.Button("Analyze")
with gr.Column():
gr.Markdown("## Gemini Results")
output = gr.JSON(label="Predictions")
gr.Markdown("## Cropped Images with BioCLIP Predictions")
image_gallery = gr.Gallery(label="Images", show_label=True)
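
    # Send the API key, prompt, and image to Gemini and show the JSON plus the labeled crops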
submit_btn.click(
fn=generate_content_str,
inputs=[api_key_input, prompt_input, image_input],
outputs=[output, image_gallery],
)


# Launch the app
if __name__ == "__main__":
demo.launch()