import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as transforms
import timm

# URL for the Hugging Face checkpoint
CHECKPOINT_URL = "https://huggingface.co/ReefNet/beit_global/resolve/main/checkpoint-60.pth"

# Class labels: 73 coral genera (the order must match the classifier head in the checkpoint)
all_classes = [
    'Acanthastrea', 'Acropora', 'Agaricia', 'Alveopora', 'Astrea', 'Astreopora',
    'Caulastraea', 'Coeloseris', 'Colpophyllia', 'Coscinaraea', 'Ctenactis',
    'Cycloseris', 'Cyphastrea', 'Dendrogyra', 'Dichocoenia', 'Diploastrea',
    'Diploria', 'Dipsastraea', 'Echinophyllia', 'Echinopora', 'Euphyllia',
    'Eusmilia', 'Favia', 'Favites', 'Fungia', 'Galaxea', 'Gardineroseris',
    'Goniastrea', 'Goniopora', 'Halomitra', 'Herpolitha', 'Hydnophora',
    'Isophyllia', 'Isopora', 'Leptastrea', 'Leptoria', 'Leptoseris',
    'Lithophyllon', 'Lobactis', 'Lobophyllia', 'Madracis', 'Meandrina', 'Merulina',
    'Montastraea', 'Montipora', 'Mussa', 'Mussismilia', 'Mycedium', 'Orbicella',
    'Oulastrea', 'Oulophyllia', 'Oxypora', 'Pachyseris', 'Pavona', 'Pectinia',
    'Physogyra', 'Platygyra', 'Plerogyra', 'Plesiastrea', 'Pocillopora',
    'Podabacia', 'Porites', 'Psammocora', 'Pseudodiploria', 'Sandalolitha',
    'Scolymia', 'Seriatopora', 'Siderastrea', 'Stephanocoenia', 'Stylocoeniella',
    'Stylophora', 'Tubastraea', 'Turbinaria'
]

# Example image paths
example_images = {
    "Acropora": "coral_images/Acropora_millepora.jpg",
    "Agaricia": "coral_images/Agaricia_agaricites.jpg",
    "Acropora aculeus": "coral_images/Acropora_aculeus.jpg",
    "Montipora": "coral_images/Montipora_patula.jpg",
    "Pocillopora": "coral_images/Pocillopora_acuta.jpg",
    "Porites": "coral_images/porities_lobata.jpg",
    "Favites": "coral_images/Favites_abdita.jpg",
    "Fungia": "coral_images/Fungia_concinna.jpg",
}
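# Note: example paths are resolved relative to the app's working directory.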

# Function to load the BEiT model
def load_model(model_name):
    print(f"Loading {model_name} model...")
    model_arch = 'beitv2_large_patch16_224.in1k_ft_in22k_in1k'
    nb_classes = len(all_classes)
    drop_path = 0.1

    # Create the model skeleton; weights are loaded from the checkpoint below
    model = timm.create_model(
        model_arch,
        pretrained=False,
        num_classes=nb_classes,
        drop_path_rate=drop_path,
        use_rel_pos_bias=True,
        use_abs_pos_emb=True,
    )

    # Load checkpoint from Hugging Face
    checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
    state_dict = checkpoint.get('model', checkpoint)
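    # "relative_position_index" entries are deterministic lookup tables that
    # timm rebuilds when the model is created, so they are dropped here and
    # the remaining weights are loaded with strict=False.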
    filtered_state_dict = {k: v for k, v in state_dict.items() if "relative_position_index" not in k}
    model.load_state_dict(filtered_state_dict, strict=False)

    model.eval()  # inference mode (disables dropout / drop path)
    # Move model to CUDA if available
    if torch.cuda.is_available():
        model.cuda()
    return model

# Preprocessing transforms (resize to 224x224, normalize with ImageNet mean/std)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Initialize model
model = load_model('beit')

def predict_label(image):
    if image is None:  # Predict was clicked before a patch was cropped
        return "Please crop a patch first."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    input_tensor = preprocess(image.convert("RGB")).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    with torch.no_grad():
        outputs = model(input_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()

    return all_classes[predicted_class]
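
# Standalone sanity check (no UI), e.g. from a Python shell:
#   img = Image.open("coral_images/Acropora_millepora.jpg").convert("RGB")
#   print(predict_label(img))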

def draw_rectangle(image, x, y, size=224):
    """Draw a clear red rectangle with increased thickness."""
    image_pil = image.copy()  # Create a copy to avoid modifying the original image
    draw = ImageDraw.Draw(image_pil)
    x1, y1 = x, y
    x2, y2 = x + size, y + size
    draw.rectangle([x1, y1, x2, y2], outline="red", width=6)  # Increase the width for clarity
    return image_pil


def crop_image(image, x, y, size=224):
    image_np = np.array(image.convert("RGB"))  # force 3 channels
    h, w, _ = image_np.shape
    # Clamp so the crop window stays inside the image, and make indices ints
    x = int(min(max(x, 0), max(w - size, 0)))
    y = int(min(max(y, 0), max(h - size, 0)))
    cropped = image_np[y:y+size, x:x+size]
    return Image.fromarray(cropped)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Coral Classification with BeIT Model")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(0, 1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(0, 1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")
    
    # Crop and Predict buttons
    crop_button = gr.Button("Crop")
    predict_button = gr.Button("Predict")

    # Example table
    def load_example(example_path):
        return Image.open(example_path).convert("RGB")

    # Generate a table of example images
    with gr.Row():
        gr.Markdown("### Example Images for Quick Testing")

    with gr.Row():
        for genus, path in example_images.items():
            with gr.Column():
                thumbnail = gr.Image(value=path, interactive=False, label=genus)
                select_button = gr.Button(value=f"Select {genus}")
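                # Bind `path` via a default argument so each button keeps its
                # own image (avoids Python's late-binding closure pitfall)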
                select_button.click(fn=lambda p=path: load_example(p), inputs=None, outputs=image_input)

    # Button functionality: Crop marks the selection and extracts the patch;
    # Predict classifies the cropped patch
    def crop_and_mark(img, x, y):
        if img is None:  # nothing uploaded yet
            return None, None
        return draw_rectangle(img, x, y), crop_image(img, x, y)

    crop_button.click(fn=crop_and_mark,
                      inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
    predict_button.click(fn=predict_label, inputs=cropped_image, outputs=label_output)

    # Limit the sliders to valid top-left corners for a 224x224 crop
    def update_sliders(image):
        if image is not None:
            width, height = image.size
            return gr.update(maximum=max(width - 224, 0)), gr.update(maximum=max(height - 224, 0))
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

# Bind to all interfaces on port 7860 (Gradio's default, and what hosted
# deployments such as Hugging Face Spaces expect)
demo.launch(server_name="0.0.0.0", server_port=7860)
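
# To run locally (assuming this file is saved as app.py):
#   python app.py
# then open http://localhost:7860 in a browser.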