import random
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from torchvision.transforms import ColorJitter, functional as F
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
import torch.nn.functional as nnf  # torchvision's functional (imported as F below) has no interpolate
from datasets import load_dataset
import evaluate
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the models
original_model_id = "guimCC/segformer-v0-gta"
lora_model_id = "guimCC/segformer-v0-gta-cityscapes"
original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id).to(device)
lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id).to(device)
# Load the dataset and select the first 10 images
dataset = load_dataset("Chris1/cityscapes", split="validation")
sampled_dataset = dataset.select(range(10)) # Select the first 10 examples
# Define a color-jitter transform to perturb the input images
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
# Initialize mIoU metric
metric = evaluate.load("mean_iou")
# Cityscapes label mapping: 19 train classes plus an "ignore" class at index 19
id2label = {
    0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence', 5: 'pole',
    6: 'traffic light', 7: 'traffic sign', 8: 'vegetation', 9: 'terrain',
    10: 'sky', 11: 'person', 12: 'rider', 13: 'car', 14: 'truck', 15: 'bus',
    16: 'train', 17: 'motorcycle', 18: 'bicycle', 19: 'ignore'
}
processor = SegformerImageProcessor()
# Cityscapes color palette
palette = np.array([
    [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
    [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
    [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
    [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32], [0, 0, 0]
])
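# Row i is the display color for class id i (the standard Cityscapes palette);
# the final black row renders the "ignore" class.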
def handle_grayscale_image(image):
    np_image = np.array(image)
    if np_image.ndim == 2:  # Grayscale image: replicate to 3 channels
        np_image = np.tile(np.expand_dims(np_image, -1), (1, 1, 3))
    return Image.fromarray(np_image)
def preprocess_image(image):
    image = handle_grayscale_image(image)
    image = jitter(image)  # Apply color jitter
    pixel_values = F.to_tensor(image).unsqueeze(0)  # Convert to tensor and add batch dimension
    return pixel_values.to(device)
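# Note: this bypasses SegformerImageProcessor's resize/normalize step. A sketch of
# the processor-based alternative (not what this Space does):
#
#     inputs = processor(images=image, return_tensors="pt")
#     pixel_values = inputs.pixel_values.to(device)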
def postprocess_predictions(logits):
    logits = logits.squeeze().detach().cpu().numpy()  # (num_labels, H/4, W/4)
    segmentation = np.argmax(logits, axis=0).astype(np.uint8)  # Per-pixel class id as 8-bit integer
    return segmentation
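# The maps above stay at SegFormer's native 1/4 output resolution. A minimal sketch
# (unused below) of nearest-neighbor upsampling back to the input size, where
# `size` is (height, width):
def upsample_segmentation(segmentation, size):
    return np.array(Image.fromarray(segmentation).resize((size[1], size[0]), Image.NEAREST))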
def compute_miou(logits, labels):
    with torch.no_grad():
        logits_tensor = torch.from_numpy(logits)
        # Upscale the logits to the size of the label map
        # (nnf is torch.nn.functional; torchvision's F has no interpolate)
        logits_tensor = nnf.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)
        pred_labels = logits_tensor.detach().cpu().numpy()
        # Crude fallback: force the label array into the prediction's shape
        if pred_labels.shape != labels.shape:
            labels = np.resize(labels, pred_labels.shape)
        pred_labels = [pred_labels]  # metric.compute expects iterables of label maps
        labels = [labels]
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=19,
            reduce_labels=processor.do_reduce_labels,
        )
        mean_iou = metrics.get('mean_iou', 0.0)
        if np.isnan(mean_iou):
            mean_iou = 0.0  # Handle NaN values gracefully
        return mean_iou
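# compute_miou is defined but never called by the app below. A hedged sketch of how
# it could be wired into inference(), assuming the dataset exposes ground-truth
# masks under a "semantic_segmentation" column (the column name is an assumption):
#
#     labels = np.array(sampled_dataset[index]["semantic_segmentation"])
#     miou = compute_miou(original_outputs.logits.detach().cpu().numpy(), labels)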
def apply_color_palette(segmentation):
    colored_segmentation = palette[segmentation]  # (H, W) class ids -> (H, W, 3) RGB
    return Image.fromarray(colored_segmentation.astype(np.uint8))
def create_legend():
    # Define the font, falling back to PIL's built-in default
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()
    # Calculate legend dimensions: two items per row, 20 px per row
    num_classes = len(id2label)
    legend_height = 20 * ((num_classes + 1) // 2)
    legend_width = 250
    # Create a blank white image for the legend
    legend = Image.new("RGB", (legend_width, legend_height), (255, 255, 255))
    draw = ImageDraw.Draw(legend)
    # Draw each color swatch and its label
    for i, (class_id, class_name) in enumerate(id2label.items()):
        color = tuple(palette[class_id])
        x = (i % 2) * 120
        y = (i // 2) * 20
        draw.rectangle([x, y, x + 20, y + 20], fill=color)
        draw.text((x + 30, y + 5), class_name, fill=(0, 0, 0), font=font)
    return legend
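# Gradio accepts a callable as a component's default `value`, so create_legend is
# invoked to render the legend image when the interface is built.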
def inference(index, legend):
"""Run inference on the input image with both models."""
image = sampled_dataset[index]['image'] # Fetch image from the sampled dataset
pixel_values = preprocess_image(image)
# Original model inference
with torch.no_grad():
original_outputs = original_model(pixel_values=pixel_values)
original_segmentation = postprocess_predictions(original_outputs.logits)
# LoRA model inference
with torch.no_grad():
lora_outputs = lora_model(pixel_values=pixel_values)
lora_segmentation = postprocess_predictions(lora_outputs.logits)
# Apply color palette
original_segmentation_image = apply_color_palette(original_segmentation)
lora_segmentation_image = apply_color_palette(lora_segmentation)
# Return the original image, the segmentations, and mIoU
return (
image,
original_segmentation_image,
lora_segmentation_image,
)
# Create a list of image options for the user to select from
image_options = [(f"Image {i}", i) for i in range(len(sampled_dataset))]
# Create the Gradio interface
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Dropdown(label="Select Image", choices=image_options),
        gr.Image(type="pil", label="Legend", value=create_legend)
    ],
    outputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Image(type="pil", label="Original Model Prediction"),
        gr.Image(type="pil", label="LoRA Model Prediction"),
    ],
    live=True
)
# Launch the interface
iface.launch()
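# On Spaces, launch() with defaults is sufficient; locally, iface.launch(share=True)
# would also create a temporary public link (a standard Gradio option, unused here).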