File size: 6,194 Bytes
df8f91f
fe7ff55
 
 
 
 
 
 
 
0ff9fca
fe7ff55
 
 
 
 
 
 
 
 
 
2ec89b4
fe7ff55
2ec89b4
 
fe7ff55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df8f91f
fe7ff55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df8f91f
 
 
 
 
 
 
fe7ff55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df8f91f
fe7ff55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df8f91f
 
 
 
fe7ff55
df8f91f
fe7ff55
 
 
df8f91f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import random
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from torchvision.transforms import ColorJitter, functional as F
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
from datasets import load_dataset
import evaluate

# Use CUDA when torch reports it as available, otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Identifiers of the two checkpoints compared by this script
original_model_id = "guimCC/segformer-v0-gta"
lora_model_id = "guimCC/segformer-v0-gta-cityscapes"

# Load each checkpoint, then move it to the selected device
original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id)
original_model = original_model.to(device)
lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id)
lora_model = lora_model.to(device)

# Validation split of the dataset; only its first 10 examples are used
dataset = load_dataset("Chris1/cityscapes", split="validation")
sampled_dataset = dataset.select(range(10))


# Colour jitter transform applied to images in preprocess_image
jitter = ColorJitter(
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
)

# Mean-IoU metric used by compute_miou
metric = evaluate.load("mean_iou")

# Class id -> class name; ids are the positions in this list
id2label = dict(enumerate([
    'road',
    'sidewalk',
    'building',
    'wall',
    'fence',
    'pole',
    'traffic light',
    'traffic sign',
    'vegetation',
    'terrain',
    'sky',
    'person',
    'rider',
    'car',
    'truck',
    'bus',
    'train',
    'motorcycle',
    'bicycle',
    'ignore',
]))
processor = SegformerImageProcessor()

# Cityscapes colour palette: row i is the RGB colour drawn for class id i
palette = np.array([
    [128, 64, 128],   # 0: road
    [244, 35, 232],   # 1: sidewalk
    [70, 70, 70],     # 2: building
    [102, 102, 156],  # 3: wall
    [190, 153, 153],  # 4: fence
    [153, 153, 153],  # 5: pole
    [250, 170, 30],   # 6: traffic light
    [220, 220, 0],    # 7: traffic sign
    [107, 142, 35],   # 8: vegetation
    [152, 251, 152],  # 9: terrain
    [70, 130, 180],   # 10: sky
    [220, 20, 60],    # 11: person
    [255, 0, 0],      # 12: rider
    [0, 0, 142],      # 13: car
    [0, 0, 70],       # 14: truck
    [0, 60, 100],     # 15: bus
    [0, 80, 100],     # 16: train
    [0, 0, 230],      # 17: motorcycle
    [119, 11, 32],    # 18: bicycle
    [0, 0, 0],        # 19: ignore
])

def handle_grayscale_image(image):
    """Return *image* as a PIL image, expanding 2-D input to three channels.

    The input is first converted to a NumPy array. A two-dimensional array
    is replicated three times along a new trailing axis; an array of any
    other rank is passed through unchanged. The array is then turned back
    into a PIL image with ``Image.fromarray``.
    """
    pixels = np.array(image)
    if pixels.ndim != 2:
        # Already has a channel axis (or some other rank): no expansion
        return Image.fromarray(pixels)
    # (H, W) -> (H, W, 3) by repeating the single channel
    three_channel = np.repeat(pixels[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(three_channel)

def preprocess_image(image):
    """Turn an input image into a batched tensor on ``device``.

    Steps, in order: three-channel conversion via
    ``handle_grayscale_image``, the module-level ``jitter`` transform,
    conversion to a tensor, and insertion of a leading batch axis.

    Returns:
        The resulting tensor moved to ``device``.
    """
    rgb_image = handle_grayscale_image(image)
    # NOTE(review): ColorJitter is normally a random training-time
    # augmentation; applying it here presumably makes repeated runs on the
    # same image differ - confirm this is intended for inference.
    jittered_image = jitter(rgb_image)
    image_tensor = F.to_tensor(jittered_image)
    batched = image_tensor.unsqueeze(dim=0)  # add the batch dimension
    return batched.to(device)

def postprocess_predictions(logits):
    """Convert a logits tensor into a per-pixel class-index map.

    Args:
        logits: Tensor of class scores. A four-dimensional input has its
            leading axis squeezed (removed when it has size 1); the class
            axis is then expected to come first.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` holding, for every position,
        the index of the highest-scoring class.
    """
    # Drop only the leading batch axis. A bare ``squeeze()`` would also
    # collapse a height or width of 1 and change the rank of the mask.
    if logits.dim() == 4:
        logits = logits.squeeze(0)
    scores = logits.detach().cpu().numpy()
    # Arg-max over the class axis, stored as 8-bit integers
    segmentation = np.argmax(scores, axis=0).astype(np.uint8)
    return segmentation

def compute_miou(logits, labels):
    """Compute the mean IoU of predicted logits against reference labels.

    Args:
        logits: Class scores as a NumPy array or torch tensor. Expected to
            be laid out as (batch, classes, height, width), since the
            scores are resized bilinearly and the arg-max is taken over
            dimension 1.
        labels: Array-like reference label map. Its last two dimensions
            give the size the logits are resized to.

    Returns:
        The ``mean_iou`` value reported by the metric, or ``0.0`` when that
        key is missing or the value is NaN.
    """
    labels = np.asarray(labels)

    with torch.no_grad():
        # as_tensor accepts both NumPy arrays and tensors; the result is
        # moved to the CPU so it can be converted back to NumPy below.
        logits_tensor = torch.as_tensor(logits).detach().cpu()
        # Scale the logits to the size of the label. The file-level ``F`` is
        # torchvision.transforms.functional, which has no ``interpolate``,
        # so torch.nn.functional is addressed explicitly.
        upsampled = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled.argmax(dim=1).numpy()

    # Ensure the shapes of pred_labels and labels match. np.resize repeats
    # or truncates the flattened labels to fit the prediction shape.
    if pred_labels.shape != labels.shape:
        labels = np.resize(labels, pred_labels.shape)

    metrics = metric.compute(
        predictions=[pred_labels],  # the metric is given one-element lists
        references=[labels],
        num_labels=len(id2label),
        ignore_index=19,
        reduce_labels=processor.do_reduce_labels,
    )

    mean_iou = metrics.get('mean_iou', 0.0)

    # Handle NaN values gracefully
    if np.isnan(mean_iou):
        mean_iou = 0.0

    return mean_iou
    
def apply_color_palette(segmentation):
    """Colour a class-index array with ``palette`` and return a PIL image.

    Each entry of *segmentation* selects one row of the module-level
    ``palette``; the looked-up colours are cast to ``uint8`` before the
    image is built.
    """
    rgb_pixels = palette[segmentation]
    rgb_pixels = rgb_pixels.astype(np.uint8)
    return Image.fromarray(rgb_pixels)

def create_legend():
    """Draw a two-column colour legend for every entry of ``id2label``.

    Each entry is a filled square in the class colour from ``palette``
    followed by the class name in black, on a white 250-pixel-wide canvas
    with 20 pixels of height per row.

    Returns:
        The legend as an RGB PIL image.
    """
    row_height = 20
    column_offset = 120

    # Try Arial at size 15 first; use the default font if it cannot be loaded
    try:
        label_font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        label_font = ImageFont.load_default()

    # Two entries per row, rounding up for an odd number of classes
    row_count = (len(id2label) + 1) // 2
    canvas = Image.new("RGB", (250, row_height * row_count), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    for position, (class_id, class_name) in enumerate(id2label.items()):
        row, column = divmod(position, 2)
        left = column * column_offset
        top = row * row_height
        swatch_color = tuple(palette[class_id])
        pen.rectangle([left, top, left + 20, top + 20], fill=swatch_color)
        pen.text((left + 30, top + 5), class_name, fill=(0, 0, 0), font=label_font)

    return canvas

def inference(index, legend):
    """Run inference on the selected image with both models.

    Args:
        index: Position of the image in ``sampled_dataset``. ``None`` means
            no image has been selected yet.
        legend: Unused. It is accepted because the Gradio interface passes
            the legend image as the second input.

    Returns:
        A tuple ``(image, original_segmentation_image,
        lora_segmentation_image)``: the dataset image and the colourised
        predictions of the two models. All three entries are ``None`` when
        *index* is ``None``.
    """
    # The dropdown has no initial value and the interface is live, so a
    # change to the legend input can call this before an image is chosen.
    if index is None:
        return None, None, None

    image = sampled_dataset[index]['image']  # Fetch image from the sampled dataset
    pixel_values = preprocess_image(image)

    # Both forward passes share one no-grad context and the same input
    with torch.no_grad():
        original_outputs = original_model(pixel_values=pixel_values)
        lora_outputs = lora_model(pixel_values=pixel_values)

    original_segmentation = postprocess_predictions(original_outputs.logits)
    lora_segmentation = postprocess_predictions(lora_outputs.logits)

    # Apply color palette
    original_segmentation_image = apply_color_palette(original_segmentation)
    lora_segmentation_image = apply_color_palette(lora_segmentation)

    # Return the original image and the two colourised segmentations
    return (
        image,
        original_segmentation_image,
        lora_segmentation_image,
    )

# Dropdown entries: (display label, dataset index) pairs, one per sampled image
image_options = [
    (f"Image {idx}", idx)
    for idx in range(len(sampled_dataset))
]

# Build the Gradio interface: an image selector plus the legend as inputs,
# and the input image with both model predictions as outputs. ``live=True``
# re-runs ``inference`` whenever an input changes.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Dropdown(choices=image_options, label="Select Image"),
        gr.Image(value=create_legend, label="Legend", type="pil"),
    ],
    outputs=[
        gr.Image(type="pil", label=output_label)
        for output_label in (
            "Input Image",
            "Original Model Prediction",
            "LoRA Model Prediction",
        )
    ],
    live=True,
)

# Start serving the interface
iface.launch()