# bioclip-canopy / app.py
# daviddao's picture
# init
# 7cf32e7
# raw
# history blame
# 2.53 kB
import gradio as gr
import numpy as np
import torch
from PIL import Image
import open_clip
from datasets import Dataset
import os
# ---------------------------------------------------------------------------
# Module-level start-up: everything below runs once at import time and defines
# the globals (`model`, `processor`, `device`, `ds`) used by the functions in
# this file.
# ---------------------------------------------------------------------------
# Set environment variable to work around OpenMP runtime issue
# NOTE(review): this assignment runs after the imports above (including
# torch) — confirm the setting still takes effect at that point.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# Load the model and processor
# `processor` is the preprocessing callable applied to PIL images in
# embed_image(); `model` provides encode_image().
model, processor = open_clip.create_model_from_pretrained('hf-hub:imageomics/bioclip')
# Prefer the GPU when one is available; embed_image() moves its inputs to the
# same device before encoding.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Load the dataset
# The directory is expected to hold a dataset saved to disk, plus the two
# FAISS index files loaded below.
embedding_path = "./data/embeddings_bioclip_False"
ds = Dataset.load_from_disk(embedding_path)
# Load FAISS indexes
# Both indexes are attached to `ds` under the names that classify_example()
# accepts through its `index` argument.
cosine_faiss_path = os.path.join(embedding_path, "embeddings_cosine.faiss")
l2_faiss_path = os.path.join(embedding_path, "embeddings_l2.faiss")
ds.load_faiss_index("embeddings_cosine", cosine_faiss_path)
ds.load_faiss_index("embeddings_l2", l2_faiss_path)
def majority_vote(classes, scores=None):
    """Return the class with the largest total vote weight.

    Args:
        classes: Sequence of class labels, one per vote (strings, ints, ...).
        scores: Optional per-vote weights aligned with ``classes``. When
            omitted, every vote counts as 1.

    Returns:
        The winning label. Ties go to the label that sorts first, because the
        tally is seeded in ``np.unique`` (sorted) order and ``max`` keeps the
        first maximal entry.

    Raises:
        ValueError: If ``classes`` is empty.
    """
    if len(classes) == 0:
        raise ValueError("majority_vote() requires at least one class label")
    if scores is None:
        # np.ones_like(classes) inherits the dtype of `classes`; for string
        # labels that yields the *strings* '1', which cannot be added to the
        # numeric tally. Use plain integer unit weights instead.
        scores = [1] * len(classes)
    # Seed the tally in sorted label order so tie-breaking is deterministic.
    class_weights = {cls: 0 for cls in np.unique(classes)}
    for cls, weight in zip(classes, scores):
        class_weights[cls] += weight
    return max(class_weights, key=class_weights.get)
def classify_example(example, index="embeddings_l2", k=10, vote_scores=True):
    """Classify one embedded example by voting among its nearest neighbours.

    Args:
        example: Mapping with an ``"embeddings"`` entry holding the query
            vector (converted to a float32 NumPy array here).
        index: Name of the index on the global dataset ``ds`` to search.
        k: Number of neighbours to retrieve.
        vote_scores: If true, each neighbour's vote is weighted by the score
            returned by the index; otherwise every neighbour counts equally.

    Returns:
        A tuple ``(prediction, class_labels, files)``: the voted label, the
        label name of every retrieved neighbour, and the neighbours'
        ``"file"`` column.
    """
    query = np.array(example["embeddings"], dtype=np.float32)
    neighbour_scores, neighbours = ds.get_nearest_examples(index, query, k)

    # Translate the stored integer label ids into their readable names.
    label_names = ds.features["label"].names
    class_labels = [label_names[label_id] for label_id in neighbours["label"]]

    # NOTE(review): if the selected index returns distances (smaller = closer),
    # weighting votes by the raw score favours the *farthest* neighbours —
    # confirm the score semantics of the index in use.
    vote_args = (class_labels, neighbour_scores) if vote_scores else (class_labels,)
    return majority_vote(*vote_args), class_labels, neighbours["file"]
def embed_image(image: Image.Image):
    """Encode a single PIL image with the globally loaded model.

    The image is run through ``processor``, given a leading batch dimension,
    moved to ``device`` and encoded with gradient tracking disabled.

    Returns:
        A dict with one key, ``"embeddings"``, holding the encoder output
        moved back to the CPU.
    """
    batch = processor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(batch)
    return {"embeddings": image_features.cpu()}
def predict(image):
    """Gradio callback: embed ``image``, classify it, and format the results.

    Returns:
        A tuple of three values, one per output textbox: the predicted
        label, the labels of the first three neighbours, and the file paths of
        the first three neighbours (the latter two joined with ``", "``).
    """
    prediction, class_labels, file_paths = classify_example(embed_image(image))
    top_classes = ", ".join(class_labels[:3])
    top_files = ", ".join(file_paths[:3])
    return prediction, top_classes, top_files
# Web UI: one image input wired to predict(), whose three return values fill
# the three text boxes in order. The app is launched as soon as this module
# is executed.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    [
        gr.Textbox(label=box_label)
        for box_label in ("Prediction", "Top 3 Classes", "Top 3 File Paths")
    ],
    description="Upload an image to get a prediction using the BioClip model.",
    title="BioClip Image Classification",
)
iface.launch()