metadata
license: apache-2.0
base_model: google/siglip-so400m-patch14-384
tags:
  - generated_from_trainer
  - siglip
metrics:
  - accuracy
  - f1
model-index:
  - name: siglip-tagger-test-3
    results: []

siglip-tagger-test-3

This model is a fine-tuned version of google/siglip-so400m-patch14-384 on 5,000 Danbooru images (see Training and evaluation data). It achieves the following results on the evaluation set:

  • Loss: 692.4745
  • Accuracy: 0.3465
  • F1: 0.9969

Model description

This is an experimental model that predicts Danbooru tags for images.

Example

Use a pipeline

from transformers import pipeline

pipe = pipeline("image-classification", model="p1atdev/siglip-tagger-test-3", trust_remote_code=True)
pipe(
    "image.jpg",  # accepts a file path (str), a NumPy array, or a PIL image
    threshold=0.5,  # optional; defaults to 0
    return_scores=False,  # optional; defaults to False
)
  • threshold: confidence threshold. If specified, the pipeline only returns tags with a confidence >= threshold.
  • return_scores: if True, the pipeline returns the labels and their confidences as a dictionary.
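The pipeline itself lives in the repository's remote code, so the following is only a hypothetical sketch of the filtering the two parameters describe, not the actual implementation. It assumes the pipeline already has a dict of tag confidences; the function name `filter_tags`, the descending sort, and the plain list returned when `return_scores=False` are all assumptions, and the scores are made up.

```python
def filter_tags(scores: dict[str, float], threshold: float = 0.0, return_scores: bool = False):
    """Hypothetical sketch: keep tags whose confidence is >= threshold.

    Returns a dict of tag -> confidence if return_scores is True,
    otherwise a list of tag names (assumed format), highest confidence first.
    """
    kept = {tag: s for tag, s in scores.items() if s >= threshold}
    # sort by confidence, highest first
    ordered = dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))
    if return_scores:
        return ordered
    return list(ordered)


# made-up confidences for illustration
example = {"1girl": 0.98, "solo": 0.91, "smile": 0.42, "outdoors": 0.07}
print(filter_tags(example, threshold=0.5))
# ['1girl', 'solo']
print(filter_tags(example, threshold=0.5, return_scores=True))
# {'1girl': 0.98, 'solo': 0.91}
```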

Load model directly

from PIL import Image
import torch

from transformers import (
    AutoModelForImageClassification,
    AutoImageProcessor,
)

MODEL_NAME = "p1atdev/siglip-tagger-test-3"

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model.eval()
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

image = Image.open("sample.jpg").convert("RGB")  # load your image

# move pixel values to the model's device and cast them to bfloat16
inputs = processor(image, return_tensors="pt").to(model.device, model.dtype)

with torch.no_grad():
    logits = model(**inputs).logits.cpu().float()[0]
# clip the raw outputs to [0, 1], as in the original example
scores = logits.clamp(0.0, 1.0)

results = {
    model.config.id2label[i]: score.item()
    for i, score in enumerate(scores)
    if score > 0
}
results = sorted(results.items(), key=lambda x: x[1], reverse=True)

for tag, score in results:
    print(f"{tag}: {score * 100:.2f}%")
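The clip, filter, and sort step of the snippet above can be tried without downloading the model. This is a minimal sketch on plain Python lists; the toy `id2label` and logits are made up, and the `threshold` parameter is an addition for illustration (the default of 0.0 reproduces the `> 0` filter). Note that clipping, unlike a sigmoid, maps every output >= 1 to 100% and every output <= 0 to 0%.

```python
def postprocess(logits, id2label, threshold=0.0):
    """Clip raw outputs to [0, 1], keep those above `threshold`, sort descending."""
    clipped = [min(max(x, 0.0), 1.0) for x in logits]
    results = {id2label[i]: s for i, s in enumerate(clipped) if s > threshold}
    return sorted(results.items(), key=lambda item: item[1], reverse=True)


# hypothetical labels and made-up raw outputs
toy_id2label = {0: "1girl", 1: "solo", 2: "smile", 3: "outdoors"}
toy_logits = [1.7, 0.93, -2.4, 0.31]

for tag, score in postprocess(toy_logits, toy_id2label):
    print(f"{tag}: {score * 100:.2f}%")
```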

Intended uses & limitations

This model is for research use only and is not recommended for production.

Please use the wd-v1-4-tagger series by SmilingWolf or similar taggers instead.

Training and evaluation data

5,000 high-quality images from Danbooru, shuffled and split 4,500:500 into train and eval sets (the same data as p1atdev/siglip-tagger-test-2).
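The split sizes can be checked against the results table below. This sketch assumes the final partial batch is kept (the Trainer default) and no gradient accumulation; the shuffling method and the use of seed 42 for it are assumptions, since the card does not say how the split was made.

```python
import math
import random

NUM_IMAGES = 5000
NUM_TRAIN = 4500
TRAIN_BATCH_SIZE = 64
NUM_EPOCHS = 50

# hypothetical: placeholder IDs, shuffled with an assumed seed
image_ids = list(range(NUM_IMAGES))
random.Random(42).shuffle(image_ids)
train_ids, eval_ids = image_ids[:NUM_TRAIN], image_ids[NUM_TRAIN:]

# 4500 / 64 = 70.3, so the last (partial) batch makes 71 steps per epoch
steps_per_epoch = math.ceil(len(train_ids) / TRAIN_BATCH_SIZE)
print(len(train_ids), len(eval_ids))  # 4500 500
print(steps_per_epoch, steps_per_epoch * NUM_EPOCHS)  # 71 3550
```

The 71 steps per epoch and 3,550 total steps match the Step column of the training results.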

| Name | Description |
|---|---|
| Images count | 5000 |
| Supported tags | 9517 general tags. Character and rating tags are not included. See all labels in config.json |
| Image rating | 4000 general; 1000 sensitive, questionable, or explicit |
| Copyright tags | original only |
| Image score range (on search) | min: 10, max: 150 |
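A multi-label tagger like this is typically trained on multi-hot targets, one entry per supported tag. The following is a minimal sketch assuming that format; `encode_tags` and the four-label vocabulary are hypothetical, and in the real model the vector would have 9,517 entries, with the mapping built from the labels in config.json.

```python
def encode_tags(tags, label2id):
    """Turn an image's tag list into a multi-hot target vector.

    Tags missing from `label2id` (e.g. character tags) are ignored.
    """
    target = [0.0] * len(label2id)
    for tag in tags:
        idx = label2id.get(tag)
        if idx is not None:
            target[idx] = 1.0
    return target


# hypothetical vocabulary; the real one has 9517 general tags
toy_label2id = {"1girl": 0, "solo": 1, "smile": 2, "outdoors": 3}
print(encode_tags(["1girl", "smile", "hatsune_miku"], toy_label2id))
# [1.0, 0.0, 1.0, 0.0]
```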

Training procedure

  • Loss function: AsymmetricLossOptimized (Asymmetric Loss)
    • gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False
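AsymmetricLossOptimized is the memory-optimized variant of Asymmetric Loss (ASL) from Ridnik et al.'s ASL repository. The following is a per-element sketch of the same formula in plain Python using the hyperparameters above, not the training code itself. `disable_torch_grad_focal_loss=False` only means gradients also flow through the focusing weight, which does not matter here. The reference implementation returns the sum over all labels rather than the mean, which plausibly explains why the loss values reported here are in the hundreds.

```python
import math


def _sigmoid(x):
    # numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def asymmetric_loss(logits, targets, gamma_neg=4.0, gamma_pos=1.0, clip=0.05, eps=1e-8):
    """Sketch of asymmetric loss (ASL) for one multi-label example.

    logits: raw model outputs; targets: 0/1 labels of the same length.
    Returns the summed (not averaged) loss, as in the reference ASL code.
    """
    total = 0.0
    for x, y in zip(logits, targets):
        p = _sigmoid(x)
        if y:
            # positives: focal weight (1 - p) ** gamma_pos
            total += -((1.0 - p) ** gamma_pos) * math.log(max(p, eps))
        else:
            # negatives: probability margin shifts p down by `clip`,
            # so very easy negatives (p <= clip) contribute nothing
            p_m = max(p - clip, 0.0)
            total += -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, eps))
    return total


print(f"{asymmetric_loss([0.0], [1]):.4f}")  # 0.3466
print(f"{asymmetric_loss([-20.0], [0]):.4f}")  # 0.0000
```

With `gamma_neg=4` the weight on a negative falls off quickly as its probability drops, which is the intended behavior for sparse tag vectors where most of the 9,517 labels are negative.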

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0001
  • train_batch_size: 64
  • eval_batch_size: 32
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine
  • lr_scheduler_warmup_steps: 10
  • num_epochs: 50
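Together, the learning rate, warmup, and scheduler settings define the learning rate at every step. This sketch assumes the Trainer's `cosine` scheduler (linear warmup, then a half-cycle cosine decay to 0, as in `get_cosine_schedule_with_warmup` with its default `num_cycles=0.5`) and 3,550 total steps taken from the results table; `lr_at` is a hypothetical helper.

```python
import math


def lr_at(step, base_lr=1e-4, warmup_steps=10, total_steps=3550):
    """Learning rate at `step`: linear warmup, then cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 5, 10, 1780, 3550):
    print(step, f"{lr_at(step):.2e}")
```

The peak of 1e-4 is reached at step 10, half of it remains at the midpoint of the decay (step 1780), and the rate reaches 0 at the final step.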

Training results

| Training Loss | Epoch | Step | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|---|
| 1066.981 | 1.0 | 71 | 1873.5417 | 0.1412 | 0.9939 |
| 547.3158 | 2.0 | 142 | 934.3269 | 0.1904 | 0.9964 |
| 534.6942 | 3.0 | 213 | 814.0771 | 0.2170 | 0.9966 |
| 414.1278 | 4.0 | 284 | 774.0230 | 0.2398 | 0.9967 |
| 365.4994 | 5.0 | 355 | 751.2046 | 0.2459 | 0.9967 |
| 352.3663 | 6.0 | 426 | 735.6580 | 0.2610 | 0.9967 |
| 414.3976 | 7.0 | 497 | 723.2065 | 0.2684 | 0.9968 |
| 350.8201 | 8.0 | 568 | 714.0453 | 0.2788 | 0.9968 |
| 364.5016 | 9.0 | 639 | 706.5261 | 0.2890 | 0.9968 |
| 309.1184 | 10.0 | 710 | 700.7808 | 0.2933 | 0.9968 |
| 288.5186 | 11.0 | 781 | 695.7027 | 0.3008 | 0.9968 |
| 287.4452 | 12.0 | 852 | 691.5306 | 0.3037 | 0.9968 |
| 280.9088 | 13.0 | 923 | 688.8063 | 0.3084 | 0.9969 |
| 296.8389 | 14.0 | 994 | 686.1077 | 0.3132 | 0.9968 |
| 265.1467 | 15.0 | 1065 | 683.7382 | 0.3167 | 0.9969 |
| 268.5263 | 16.0 | 1136 | 682.1683 | 0.3206 | 0.9969 |
| 309.7871 | 17.0 | 1207 | 681.1995 | 0.3199 | 0.9969 |
| 307.6475 | 18.0 | 1278 | 680.1700 | 0.3230 | 0.9969 |
| 262.0677 | 19.0 | 1349 | 679.2177 | 0.3270 | 0.9969 |
| 275.3823 | 20.0 | 1420 | 678.9730 | 0.3294 | 0.9969 |
| 273.984 | 21.0 | 1491 | 678.6031 | 0.3318 | 0.9969 |
| 273.5361 | 22.0 | 1562 | 678.1285 | 0.3332 | 0.9969 |
| 279.6474 | 23.0 | 1633 | 678.4264 | 0.3348 | 0.9969 |
| 232.5045 | 24.0 | 1704 | 678.3773 | 0.3357 | 0.9969 |
| 269.621 | 25.0 | 1775 | 678.4922 | 0.3372 | 0.9969 |
| 289.8389 | 26.0 | 1846 | 679.0094 | 0.3397 | 0.9969 |
| 256.7373 | 27.0 | 1917 | 679.5618 | 0.3407 | 0.9969 |
| 262.3969 | 28.0 | 1988 | 680.1168 | 0.3414 | 0.9969 |
| 266.2439 | 29.0 | 2059 | 681.0101 | 0.3421 | 0.9969 |
| 247.7932 | 30.0 | 2130 | 681.9800 | 0.3422 | 0.9969 |
| 246.8083 | 31.0 | 2201 | 682.8550 | 0.3416 | 0.9969 |
| 270.827 | 32.0 | 2272 | 683.9250 | 0.3434 | 0.9969 |
| 256.4384 | 33.0 | 2343 | 685.0451 | 0.3448 | 0.9969 |
| 270.461 | 34.0 | 2414 | 686.2427 | 0.3439 | 0.9969 |
| 253.8104 | 35.0 | 2485 | 687.4274 | 0.3441 | 0.9969 |
| 265.532 | 36.0 | 2556 | 688.4856 | 0.3451 | 0.9969 |
| 249.1426 | 37.0 | 2627 | 689.5027 | 0.3457 | 0.9969 |
| 229.5651 | 38.0 | 2698 | 690.4455 | 0.3455 | 0.9969 |
| 251.9008 | 39.0 | 2769 | 691.2324 | 0.3463 | 0.9969 |
| 281.8228 | 40.0 | 2840 | 691.7993 | 0.3464 | 0.9969 |
| 242.5272 | 41.0 | 2911 | 692.1788 | 0.3465 | 0.9969 |
| 229.5605 | 42.0 | 2982 | 692.3799 | 0.3465 | 0.9969 |
| 245.0876 | 43.0 | 3053 | 692.4745 | 0.3465 | 0.9969 |
| 271.22 | 44.0 | 3124 | 692.5084 | 0.3465 | 0.9969 |
| 244.3045 | 45.0 | 3195 | 692.5108 | 0.3465 | 0.9969 |
| 243.9542 | 46.0 | 3266 | 692.5128 | 0.3465 | 0.9969 |
| 274.6664 | 47.0 | 3337 | 692.5095 | 0.3465 | 0.9969 |
| 231.1361 | 48.0 | 3408 | 692.5107 | 0.3465 | 0.9969 |
| 274.5513 | 49.0 | 3479 | 692.5108 | 0.3465 | 0.9969 |
| 316.0833 | 50.0 | 3550 | 692.5107 | 0.3465 | 0.9969 |

Framework versions

  • Transformers 4.37.2
  • Pytorch 2.1.2+cu118
  • Datasets 2.16.1
  • Tokenizers 0.15.0