import gradio as gr
import json
import time
import os
from google import genai
from google.genai import types
from google.genai import errors
from bioclip import TreeOfLifeClassifier, Rank

# Number of times to retry a Gemini call that fails with a server error.
PROMPT_RETRIES = 2

DEFAULT_PROMPT = """
Return bounding boxes and a description for each species in this image. Ensure you only return valid JSON.
""".strip()

# Initialize classifier outside of functions so the (expensive) model load
# happens once at startup instead of per request.
classifier = TreeOfLifeClassifier()


def crop_image(image, gemini_bounding_box):
    """
    Crop the image based on the bounding box coordinates.

    :param image: PIL Image object
    :param gemini_bounding_box: Tuple of (y_min, x_min, y_max, x_max) in the
        Gemini convention: normalized to the range 0-1000
    :return: Cropped PIL Image
    """
    width, height = image.size
    y_min, x_min, y_max, x_max = gemini_bounding_box

    # Convert normalized 0-1000 coordinates to pixel values.
    left = int(x_min / 1000 * width)
    upper = int(y_min / 1000 * height)
    right = int(x_max / 1000 * width)
    lower = int(y_max / 1000 * height)

    # Crop and return the image.
    return image.crop((left, upper, right, lower))


def predict_species(img):
    """Return the top-1 BioCLIP species prediction dict for a PIL image."""
    predictions = classifier.predict([img], Rank.SPECIES, k=1)
    return predictions[0]


def make_crops(image, predictions_json_txt):
    """
    Process predictions to crop images based on bounding boxes.

    :param image: PIL Image object
    :param predictions_json_txt: str of JSON List of prediction dictionaries
        containing bounding boxes (each with an optional "box_2d" key)
    :return: List of cropped images (empty on JSON parse failure)
    """
    cropped_images = []
    try:
        predictions = json.loads(predictions_json_txt)
    except json.JSONDecodeError as e:
        print(str(e))
        return []  # Return empty list if JSON parsing fails

    for prediction in predictions:
        if "box_2d" in prediction:
            gemini_bounding_box = prediction["box_2d"]
            # Crop the image using the bounding box; best-effort — a single
            # malformed box should not abort the remaining crops.
            try:
                cropped_image = crop_image(image, gemini_bounding_box)
                cropped_images.append(cropped_image)
            except Exception as e:
                print(f"Error cropping image: {e}")
    return cropped_images


def generate_content_str(api_key, prompt, pil_image, tries=PROMPT_RETRIES):
    """
    Ask Gemini for bounding boxes, crop each detection, and label crops
    with BioCLIP species predictions.

    :param api_key: Gemini API key string
    :param prompt: prompt text (must request bounding boxes as JSON)
    :param pil_image: PIL Image object to analyze
    :param tries: number of attempts allowed on server errors
    :return: tuple of (raw JSON text from Gemini,
             list of (cropped PIL Image, label string) pairs)
    :raises google.genai.errors.ServerError: after the retry budget is spent
    """
    # Initialize the client with the provided API key.
    client = genai.Client(api_key=api_key)
    generate_content_config = types.GenerateContentConfig(
        response_mime_type="application/json",
    )
    while True:
        try:
            response = client.models.generate_content(
                model="gemini-2.5-pro-exp-03-25",
                contents=[prompt, pil_image],
                config=generate_content_config,
            )
            print("Result", response.text)
            crop_images = make_crops(
                image=pil_image, predictions_json_txt=response.text
            )
            # Attach BioCLIP labels for the Gradio Gallery.
            crop_images_with_labels = []
            for img in crop_images:
                prediction = predict_species(img)
                label = f"{prediction['common_name']} - {prediction['species']} - {round(prediction['score'],3)}"
                crop_images_with_labels.append((img, label))
            return response.text, crop_images_with_labels
        except errors.ServerError as e:
            tries -= 1
            # `<= 0` (not `== 0`) guards against an infinite loop if the
            # caller passes tries=0 or a negative value.
            if tries <= 0:
                raise e
            print(f"Retrying... \n{e}")
            time.sleep(5)


# Define the Gradio interface
with gr.Blocks(title="Gemini 2.5 Pro Explore") as demo:
    gr.Markdown("# Image Analysis with Gemini 2.5 Pro + BioCLIP")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Upload an image and enter a prompt to get predictions")
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key here...",
                type="password",
            )
            image_input = gr.Image(label="Upload an image", type="pil")
            gr.Markdown("The prompt below must request bounding boxes.")
            prompt_input = gr.TextArea(
                label="Enter your prompt",
                placeholder="Describe what you want to analyze...",
                value=DEFAULT_PROMPT,
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            gr.Markdown("## Gemini Results")
            output = gr.JSON(label="Predictions")
            gr.Markdown("## Cropped Images with BioCLIP Predictions")
            image_gallery = gr.Gallery(label="Images", show_label=True)

    submit_btn.click(
        fn=generate_content_str,
        inputs=[api_key_input, prompt_input, image_input],
        outputs=[output, image_gallery],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()