# ClipSegMultiClass

Multiclass semantic segmentation using CLIP + CLIPSeg.

A fine-tuned version of [`CIDAS/clipseg-rd64-refined`](https://huggingface.co/CIDAS/clipseg-rd64-refined) that supports multiple classes in a single forward pass.
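One way to picture a multiclass pass is as a stack of one heatmap per text prompt, collapsed by a per-pixel argmax into a single class-index mask. The sketch below illustrates that idea with dummy logits; the shapes and the one-prompt-per-class setup are assumptions for illustration, not the model's actual internals.

```python
import numpy as np

# Dummy per-class logit maps: [num_classes, H, W].
# Assumed layout: index 0 = background, then Pig, Horse, Sheep.
rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 352, 352))

# Per-pixel argmax collapses the stack into one class-index mask.
mask = logits.argmax(axis=0)  # [H, W], values in {0, 1, 2, 3}

print(mask.shape)  # (352, 352)
```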
|
|
|
---
|
|
|
## Model

**Name:** [`BioMike/clipsegmulticlass_v1`](https://huggingface.co/BioMike/clipsegmulticlass_v1)

**Repository:** [github.com/BioMikeUkr/clipsegmulticlass](https://github.com/BioMikeUkr/clipsegmulticlass)

**Base:** `CIDAS/clipseg-rd64-refined`

**Classes:** `["background", "Pig", "Horse", "Sheep"]`

**Image Size:** 352×352

**Trained on:** OpenImages segmentation subset (custom fruit/animal dataset)
|
|
|
---
|
|
|
## Evaluation

| Model                        | Precision | Recall | F1 Score | Accuracy |
|------------------------------|-----------|--------|----------|----------|
| CIDAS/clipseg-rd64-refined   | 0.5239    | 0.2114 | 0.2882   | 0.2665   |
| BioMike/clipsegmulticlass_v1 | 0.7460    | 0.5035 | 0.6009   | 0.6763   |
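The table's metrics are consistent with per-class (macro-averaged) precision, recall, and F1 over integer class masks. The repo's actual evaluation script is not shown here, so the function below is an illustrative sketch of that style of metric, not the exact code used:

```python
import numpy as np

def per_class_metrics(pred: np.ndarray, target: np.ndarray, num_classes: int):
    """Macro-averaged precision, recall, and F1 over integer class masks."""
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = np.sum((pred == c) & (target == c))
        fp = np.sum((pred == c) & (target != c))
        fn = np.sum((pred != c) & (target == c))
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    return np.mean(precisions), np.mean(recalls), np.mean(f1s)

# Tiny worked example with 2x2 masks and 3 classes.
pred = np.array([[0, 1], [2, 2]])
target = np.array([[0, 1], [2, 1]])
p, r, f = per_class_metrics(pred, target, num_classes=3)
```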
|
|
|
---
|
|
|
## Demo

Try it online: [Hugging Face Space](https://huggingface.co/spaces/BioMike/clipsegmulticlass)
|
|
|
---
|
|
|
## Usage

```python
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np

from model import ClipSegMultiClassModel
from config import ClipSegMultiClassConfig

# Load model
model = ClipSegMultiClassModel.from_pretrained("trained_clipseg_multiclass").to("cuda").eval()
config = model.config  # contains label2color

# Load image
image = Image.open("pigs.jpg").convert("RGB")

# Run inference
mask = model.predict(image)  # shape: [1, H, W]

# Visualize
def visualize_mask(mask_tensor: torch.Tensor, label2color: dict) -> np.ndarray:
    if mask_tensor.dim() == 3:
        mask_tensor = mask_tensor.squeeze(0)

    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)  # [H, W]
    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    # Paint each class index with its palette color.
    for class_idx, color in label2color.items():
        color_mask[mask_np == class_idx] = color

    return color_mask

color_mask = visualize_mask(mask, config.label2color)

plt.imshow(color_mask)
plt.axis("off")
plt.title("Predicted Segmentation Mask")
plt.show()
```
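`visualize_mask` is essentially an integer-indexed palette lookup via boolean masking. The same idea in plain NumPy, using a hypothetical palette (the real colors live in the model config's `label2color` and may differ):

```python
import numpy as np

# Hypothetical palette for the four classes (illustrative values only).
label2color = {
    0: (0, 0, 0),      # background
    1: (255, 0, 0),    # Pig
    2: (0, 255, 0),    # Horse
    3: (0, 0, 255),    # Sheep
}

# A tiny 2x2 class-index mask standing in for a real prediction.
mask = np.array([[0, 1], [2, 3]], dtype=np.uint8)

# Boolean masking paints every pixel of a class in one assignment.
color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
for class_idx, color in label2color.items():
    color_mask[mask == class_idx] = color

print(color_mask[0, 1])  # the "Pig" pixel, colored (255, 0, 0)
```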