import streamlit as st
from PIL import Image
import pandas as pd
import io
import os
import requests
from autogluon.multimodal import MultiModalPredictor
from huggingface_hub import snapshot_download
import logging
import datetime
import re

# Configure logging
log_filename = "model_predictions.log"
logging.basicConfig(
    filename=log_filename,
    level=logging.INFO,
    format='%(asctime)s - %(message)s'
)

# Set the page config
st.set_page_config(page_title="Honey Bee Image Classification", layout="wide")


@st.cache_resource
def load_model():
    """Download the model snapshot from the Hugging Face Hub and load it once."""
    repo_id = "Honey-Bee-Society/honeybee_ml_v1"
    local_dir = snapshot_download(repo_id)
    assets_path = os.path.join(local_dir, "assets.json")
    model_checkpoint = os.path.join(local_dir, "model.ckpt")
    if not os.path.exists(assets_path) or not os.path.exists(model_checkpoint):
        raise FileNotFoundError("Required model files not found in the downloaded directory.")
    return MultiModalPredictor.load(local_dir)


def resize_image_proportionally(image, max_size_mb=1):
    """Scale the image down (keeping aspect ratio) until its PNG encoding fits max_size_mb."""
    img_byte_array = io.BytesIO()
    image.save(img_byte_array, format='PNG')
    img_size = len(img_byte_array.getvalue()) / (1024 * 1024)
    if img_size > max_size_mb:
        scale_factor = (max_size_mb / img_size) ** 0.5
        new_width = int(image.width * scale_factor)
        new_height = int(image.height * scale_factor)
        image = image.resize((new_width, new_height))
    return image


def predict_image(image, predictor):
    """Run the predictor on a single image and return the class-probability DataFrame."""
    img_byte_array = io.BytesIO()
    image.save(img_byte_array, format='PNG')
    img_data = img_byte_array.getvalue()
    df = pd.DataFrame({"image": [img_data]})
    probabilities = predictor.predict_proba(df, realtime=True)
    return probabilities


def save_image(image, img_name, target_size_kb=500):
    """Save the image as a JPEG under processed_images/, lowering quality until it fits target_size_kb."""
    processed_image_path = os.path.join("processed_images", img_name)
    if not os.path.exists("processed_images"):
        os.makedirs("processed_images")

    # JPEG cannot store an alpha channel, so convert e.g. RGBA PNG uploads first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    quality = 95
    img_byte_array = io.BytesIO()
    while quality > 10:
        # Reset the buffer so bytes from a previous, larger encode are discarded.
        img_byte_array.seek(0)
        img_byte_array.truncate(0)
        image.save(img_byte_array, format='JPEG', quality=quality)
        img_size_kb = len(img_byte_array.getvalue()) / 1024
        if img_size_kb <= target_size_kb:
            break
        quality -= 5

    with open(processed_image_path, "wb") as f:
        f.write(img_byte_array.getvalue())
    return processed_image_path


def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score):
    logging.info(
        f"Image Path: {image_path}, "
        f"Honeybee: {honeybee_score:.2f}%, "
        f"Bumblebee: {bumblebee_score:.2f}%, "
        f"Vespidae: {vespidae_score:.2f}%"
    )


def sanitize_filename(filename):
    safe_filename = re.sub(r'[^A-Za-z0-9_.-]', '_', filename)
    return safe_filename


def check_file_size(uploaded_file, max_size_mb=10):
    uploaded_file.seek(0, os.SEEK_END)
    file_size = uploaded_file.tell() / (1024 * 1024)
    uploaded_file.seek(0)
    if file_size > max_size_mb:
        st.error(f"File size exceeds {max_size_mb}MB limit. Please upload a smaller file.")
        return False
    return True


def run_api(predictor):
    """
    'API mode' for this Streamlit app. Expects query params ?api=1&image_url=<IMAGE_URL>

    Example usage:
        curl "https://YOUR-SPACE.hf.space/?api=1&image_url=<IMAGE_URL>"

    WARNING: You will still get HTML with embedded JSON. That's a Streamlit limitation.
    """
    # Use st.query_params (not the deprecated st.experimental_get_query_params).
    # It returns plain strings, so no list indexing is needed.
    params = st.query_params
    image_url = params.get("image_url")

    if not image_url:
        st.json({"error": "No 'image_url' provided. Usage: ?api=1&image_url=<IMAGE_URL>"})
        st.stop()

    # Download the image
    response = requests.get(
        image_url,
        headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://honeybeeclassification.streamlit.app)"},
        timeout=30,  # avoid hanging indefinitely on slow or unreachable hosts
    )
    if response.status_code != 200:
        st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
        st.stop()

    image_bytes = response.content

    # Check file size (limit 10MB)
    image_size_mb = len(image_bytes) / (1024 * 1024)
    if image_size_mb > 10:
        st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
        st.stop()

    # Convert to PIL
    try:
        image = Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        st.json({"error": f"Could not open image: {e}"})
        st.stop()

    # Resize
    image = resize_image_proportionally(image)

    # Predict: columns 1, 2 and 3 of the probability DataFrame correspond to
    # Honey Bee, Bumblebee and Vespidae respectively.
    try:
        probabilities = predict_image(image, predictor)
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100
    except Exception as e:
        st.json({"error": f"Prediction failed: {e}"})
        st.stop()

    # Determine highest-scoring label
    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < 80:
        prediction_label = "No bee detected (scores too low)."
    else:
        if honeybee_score == highest_score:
            prediction_label = "Honey Bee"
        elif bumblebee_score == highest_score:
            prediction_label = "Bumblebee"
        else:
            prediction_label = "Vespidae (wasp/hornet)"

    # Return results as JSON, but note that Streamlit wraps this in HTML
    st.json({
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    })

    # Stop execution so the normal UI won't render
    st.stop()


def run_ui(predictor):
    st.title("Honey Bee Image Classification")

    uploaded_file = st.file_uploader(
        "Upload a photo of the suspected bee...",
        type=["png", "jpg", "jpeg"]
    )

    with st.expander("ML Model Details"):
        st.write("""
            We trained a MultiModalPredictor to classify bee images
            (Honey Bee, Bumblebee, or Vespidae).
            Accuracy is ~97.5% on our test set.
        """)

    if uploaded_file is not None:
        if check_file_size(uploaded_file):
            image = Image.open(uploaded_file)
            image = resize_image_proportionally(image)

            progress_bar = st.progress(0)
            try:
                probabilities = predict_image(image, predictor)
                progress_bar.progress(100)

                honeybee_score = float(probabilities[1].iloc[0]) * 100
                bumblebee_score = float(probabilities[2].iloc[0]) * 100
                vespidae_score = float(probabilities[3].iloc[0]) * 100

                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                sanitized_filename = sanitize_filename(uploaded_file.name)
                img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"
                image_path = save_image(image, img_name)

                log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)

                highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
                if highest_score < 80:
                    st.warning("We are fairly confident there is no bee in this photo.")
                else:
                    if honeybee_score == highest_score:
                        st.success("Yes! This is a honey bee!")
                    elif bumblebee_score == highest_score:
                        st.info("Likely a bumblebee, not a honey bee.")
                    else:
                        st.info("Likely a wasp/hornet (vespidae).")
            except Exception as e:
                st.error(f"An error occurred: {e}")
            finally:
                progress_bar.empty()


def main():
    predictor = load_model()

    # Decide whether we are in 'API mode' or normal UI mode
    query_params = st.query_params  # replaces st.experimental_get_query_params
    if "api" in query_params:
        run_api(predictor)
    else:
        run_ui(predictor)


if __name__ == '__main__':
    main()
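

# ---------------------------------------------------------------------------
# Hypothetical offline smoke test (an assumption, not part of the original
# app): a minimal sketch for exercising the model outside the Streamlit UI,
# e.g. to verify that the checkpoint downloads and that predict_proba returns
# the class-probability columns (1, 2 and 3) used above. It bypasses the
# cached load_model() so no Streamlit session is required; "bee_sample.jpg"
# is a placeholder path.
def _offline_smoke_test(sample_path="bee_sample.jpg"):
    local_dir = snapshot_download("Honey-Bee-Society/honeybee_ml_v1")
    predictor = MultiModalPredictor.load(local_dir)
    img = resize_image_proportionally(Image.open(sample_path))
    print(predict_image(img, predictor))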