Spaces:
Sleeping
Sleeping
File size: 6,106 Bytes
0d11696 471d95f fe0c1e0 861d6d7 fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f 0d11696 471d95f 0d11696 fe0c1e0 0d11696 fe0c1e0 0d11696 fe0c1e0 0d11696 471d95f 0d11696 471d95f fe0c1e0 0d11696 471d95f 0d11696 471d95f 0d11696 471d95f fe0c1e0 0d11696 fe0c1e0 0d11696 fe0c1e0 0d11696 fe0c1e0 0d11696 471d95f 0d11696 fe0c1e0 0d11696 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as transforms
import timm
# Download URL of the model checkpoint hosted on Hugging Face.
# Split across two literals only to keep the line short; the value is unchanged.
CHECKPOINT_URL = (
    "https://huggingface.co/ReefNet/beit_global"
    "/resolve/main/checkpoint-60.pth"
)
# Class labels, one row per initial letter.
# The position of each name is significant: predict_label maps the argmax
# index of the model output straight into this list, so entries must not be
# reordered (presumably the order matches the training label order -- confirm
# against the checkpoint before editing).
all_classes = [
    "Acanthastrea", "Acropora", "Agaricia", "Alveopora", "Astrea", "Astreopora",
    "Caulastraea", "Coeloseris", "Colpophyllia", "Coscinaraea", "Ctenactis",
    "Cycloseris", "Cyphastrea",
    "Dendrogyra", "Dichocoenia", "Diploastrea", "Diploria", "Dipsastraea",
    "Echinophyllia", "Echinopora", "Euphyllia", "Eusmilia",
    "Favia", "Favites", "Fungia",
    "Galaxea", "Gardineroseris", "Goniastrea", "Goniopora",
    "Halomitra", "Herpolitha", "Hydnophora",
    "Isophyllia", "Isopora",
    "Leptastrea", "Leptoria", "Leptoseris", "Lithophyllon", "Lobactis",
    "Lobophyllia",
    "Madracis", "Meandrina", "Merulina", "Montastraea", "Montipora", "Mussa",
    "Mussismilia", "Mycedium",
    "Orbicella", "Oulastrea", "Oulophyllia", "Oxypora",
    "Pachyseris", "Pavona", "Pectinia", "Physogyra", "Platygyra", "Plerogyra",
    "Plesiastrea", "Pocillopora", "Podabacia", "Porites", "Psammocora",
    "Pseudodiploria",
    "Sandalolitha", "Scolymia", "Seriatopora", "Siderastrea", "Stephanocoenia",
    "Stylocoeniella", "Stylophora",
    "Tubastraea", "Turbinaria",
]
# Example images offered in the UI: display label -> image path.
# Insertion order is the order in which the thumbnails are laid out.
example_images = {
    "Acropora": "coral_images/Acropora_millepora.jpg",
    "Agaricia": "coral_images/Agaricia_agaricites.jpg",
    "Acropora aculeus": "coral_images/Acropora_aculeus.jpg",
    "Montipora": "coral_images/Montipora_patula.jpg",
    "Pocillopora": "coral_images/Pocillopora_acuta.jpg",
    # NOTE(review): "porities" looks like a misspelling of "porites"; left
    # untouched because it may be the actual file name on disk -- confirm.
    "Porites": "coral_images/porities_lobata.jpg",
    "Favites": "coral_images/Favites_abdita.jpg",
    "Fungia": "coral_images/Fungia_concinna.jpg",
}
# Function to load the BeIT model
def load_model(model_name):
    """Build the BEiT v2 classifier and load its checkpoint from CHECKPOINT_URL.

    Args:
        model_name: Label used only in the progress message; the architecture
            is fixed to ``beitv2_large_patch16_224.in1k_ft_in22k_in1k``.

    Returns:
        The timm model in eval mode, moved to CUDA when a GPU is available.
    """
    print(f"Loading {model_name} model...")

    # Create model (no pretrained weights: they come from the checkpoint below)
    model = timm.create_model(
        'beitv2_large_patch16_224.in1k_ft_in22k_in1k',
        pretrained=False,
        num_classes=len(all_classes),
        drop_path_rate=0.1,
        use_rel_pos_bias=True,
        use_abs_pos_emb=True,
    )

    # Load checkpoint from Hugging Face.
    # NOTE: the download is deserialized via torch.load (pickle-based), so
    # CHECKPOINT_URL must only ever point at a trusted source.
    checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
    # The weights may be nested under a 'model' key or be the top-level dict.
    state_dict = checkpoint.get('model', checkpoint)
    # Skip the relative_position_index entries (presumably index buffers the
    # model rebuilds on construction -- TODO confirm).
    filtered_state_dict = {
        key: value
        for key, value in state_dict.items()
        if "relative_position_index" not in key
    }

    # strict=False tolerates key mismatches; report them instead of hiding
    # them so a wrong or incomplete checkpoint does not go unnoticed.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: checkpoint is missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: checkpoint has unexpected keys: {load_result.unexpected_keys}")

    model.eval()
    # Move model to CUDA if available
    if torch.cuda.is_available():
        model.cuda()
    return model
# Preprocessing transforms: resize to 224x224, convert to a tensor, then
# normalize each channel. The statistics below are the values commonly used
# for ImageNet-trained models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Initialize model once at import time so every prediction reuses it
model = load_model('beit')
def predict_label(image):
    """Classify an image patch and return the predicted class name.

    Args:
        image: A PIL image, a NumPy array accepted by ``Image.fromarray``,
            or ``None`` when no patch is available yet.

    Returns:
        The entry of ``all_classes`` with the highest model score, or the
        message ``"No image provided"`` when *image* is ``None``.
    """
    # Predict can be clicked before anything was cropped; answer with a hint
    # instead of failing inside the transform pipeline.
    if image is None:
        return "No image provided"

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The normalization step has three-channel statistics, so RGBA,
    # grayscale or palette inputs must be converted first.
    image = image.convert("RGB")

    # Add the batch dimension the model expects.
    input_tensor = preprocess(image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    with torch.no_grad():
        outputs = model(input_tensor)
    predicted_class = torch.argmax(outputs, dim=1).item()
    return all_classes[predicted_class]
def draw_rectangle(image, x, y, size=224):
    """Return a copy of *image* with a thick red square marking a patch.

    The top-left corner is clamped into the image, mirroring the clamping
    done by ``crop_image``, so the outline stays inside the picture.

    Args:
        image: PIL image to annotate, or ``None``.
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the square patch, in pixels.

    Returns:
        An annotated copy of the image (the input is not modified), or
        ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    width, height = image.size
    # Slider values may arrive as floats; pixel coordinates must be ints.
    left = min(max(int(x), 0), max(width - size, 0))
    top = min(max(int(y), 0), max(height - size, 0))
    # ImageDraw treats the second corner as inclusive, so the last pixel of a
    # patch starting at `left` is `left + size - 1` (capped at the image edge).
    right = min(left + size, width) - 1
    bottom = min(top + size, height) - 1

    image_pil = image.copy()  # Create a copy to avoid modifying the original image
    draw = ImageDraw.Draw(image_pil)
    draw.rectangle([left, top, right, bottom], outline="red", width=6)
    return image_pil
def crop_image(image, x, y, size=224):
    """Cut a square patch out of *image* and return it as a PIL image.

    The requested corner is clamped so the patch lies inside the image. When
    the image is smaller than *size* along an axis, the patch starts at 0 on
    that axis and is as large as the image allows.

    Args:
        image: PIL image (or array-like) to crop, or ``None``.
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the square patch, in pixels.

    Returns:
        The cropped patch as a PIL image, or ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    image_np = np.array(image)
    # Only the first two axes are spatial; this also accepts 2-D grayscale
    # arrays, which have no channel axis to unpack.
    h, w = image_np.shape[:2]
    # int(): slider values may arrive as floats, which cannot index an array.
    # max(..., 0): keep the upper bound non-negative for images smaller than
    # the patch, otherwise the slice would start from a negative index.
    x = min(max(int(x), 0), max(w - size, 0))
    y = min(max(int(y), 0), max(h - size, 0))
    cropped = image_np[y:y + size, x:x + size]
    return Image.fromarray(cropped)
# Gradio UI
# NOTE(review): the nesting of the layout below was reconstructed from a
# listing whose indentation was lost -- confirm it matches the intended
# component hierarchy.
with gr.Blocks() as demo:
    gr.Markdown("## Coral Classification with BeIT Model")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(0, 1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(0, 1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Crop and Predict buttons
    crop_button = gr.Button("Crop")
    predict_button = gr.Button("Predict")

    # Example table
    def load_example(example_path):
        """Read the image at *example_path* and return it as an RGB PIL image."""
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy that stays valid afterwards.
        with Image.open(example_path) as example:
            return example.convert("RGB")

    # Generate table of examples
    with gr.Row():
        gr.Markdown("### Example Images for Quick Testing")
    with gr.Row():
        for genus, path in example_images.items():
            with gr.Column():
                gr.Image(value=path, interactive=False, label=genus)
                select_button = gr.Button(value=f"Select {genus}")
                # `p=path` binds the current path at definition time so each
                # button loads its own image rather than the last one.
                select_button.click(fn=lambda p=path: load_example(p), inputs=None, outputs=image_input)

    # Button functionality
    def crop_and_outline(img, x, y):
        """Return (image with the patch outlined, cropped patch) for the UI."""
        # Crop can be clicked before an image is uploaded; clear both outputs
        # instead of raising.
        if img is None:
            return None, None
        return draw_rectangle(img, x, y), crop_image(img, x, y)

    crop_button.click(fn=crop_and_outline,
                      inputs=[image_input, x_slider, y_slider],
                      outputs=[interactive_image, cropped_image])
    predict_button.click(fn=predict_label, inputs=cropped_image, outputs=label_output)

    def update_sliders(image):
        """Limit both sliders so a 224-pixel patch fits inside *image*."""
        if image is None:
            return gr.update(), gr.update()
        width, height = image.size
        # 224 is the default patch size of draw_rectangle / crop_image.
        # Clamp at 0 so images smaller than the patch do not produce a
        # negative slider maximum.
        return (
            gr.update(maximum=max(width - 224, 0)),
            gr.update(maximum=max(height - 224, 0)),
        )

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

# Listen on all interfaces, port 7860
demo.launch(server_name="0.0.0.0", server_port=7860)
|