In [1]:
from PIL import Image
import pandas as pd
import os
from datasets import Dataset, Image, DatasetDict
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
import evaluate
import numpy as np

### Load data

In [2]:
file2obj = pd.read_csv("../data/processed/OM_file_to_obj.csv")
file2obj["image"] = file2obj.apply(lambda x: os.path.join("..", x["root"], x["file"]), axis=1)
file2obj.rename(columns={"obj_num": "label"}, inplace=True)

# Group by 'obj_num' and count occurrences
obj_num_counts = file2obj["label"].value_counts()

# Filter rows where 'obj_num' appears more than twice
file2obj_3 = file2obj[file2obj["label"].isin(obj_num_counts[obj_num_counts > 2].index)]

### Form HF dataset

In [3]:
ds = Dataset.from_pandas(file2obj_3[["image", "label"]], preserve_index=False).cast_column(
    "image", Image()
)
ds = ds.class_encode_column("label")
trainval_test = ds.train_test_split(stratify_by_column="label", test_size=0.16)
train_val = trainval_test["train"].train_test_split(stratify_by_column="label", test_size=16 / 84)
ds = DatasetDict(
    {"train": train_val["train"], "valid": train_val["test"], "test": trainval_test["test"]}
)

Casting to class labels:   0%|          | 0/25725 [00:00<?, ? examples/s]

### Transform data

In [4]:
checkpoint = "google/efficientnet-b3"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)


normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])


def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


ds = ds.with_transform(transforms)

### Set up model and metrics

In [5]:
labels = ds["train"].features["label"].names
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,
)

data_collator = DefaultDataCollator()

accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

Some weights of EfficientNetForImageClassification were not initialized from the model checkpoint at google/efficientnet-b3 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1536]) in the checkpoint and torch.Size([3872, 1536]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3872]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Train model

In [10]:
training_args = TrainingArguments(
    output_dir="../models/test",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    # gradient_accumulation_steps=2,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,

)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],  # .select(range(100)),
    eval_dataset=ds["valid"],  # .select(range(100)),
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

trainer.train()



  0%|          | 0/274 [00:00<?, ?it/s]

{'loss': 8.0521, 'learning_rate': 1.785714285714286e-05, 'epoch': 0.04}
{'loss': 8.0927, 'learning_rate': 3.571428571428572e-05, 'epoch': 0.07}
{'loss': 8.1187, 'learning_rate': 4.959349593495935e-05, 'epoch': 0.11}
{'loss': 8.2335, 'learning_rate': 4.75609756097561e-05, 'epoch': 0.15}
{'loss': 8.2531, 'learning_rate': 4.5528455284552844e-05, 'epoch': 0.18}
{'loss': 8.2873, 'learning_rate': 4.3495934959349595e-05, 'epoch': 0.22}
{'loss': 8.2071, 'learning_rate': 4.146341463414634e-05, 'epoch': 0.26}
{'loss': 8.2287, 'learning_rate': 3.943089430894309e-05, 'epoch': 0.29}
{'loss': 8.1928, 'learning_rate': 3.739837398373984e-05, 'epoch': 0.33}
{'loss': 8.2053, 'learning_rate': 3.5365853658536584e-05, 'epoch': 0.36}
{'loss': 8.1621, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.4}
{'loss': 8.1731, 'learning_rate': 3.130081300813008e-05, 'epoch': 0.44}
{'loss': 8.1447, 'learning_rate': 2.926829268292683e-05, 'epoch': 0.47}
{'loss': 8.1161, 'learning_rate': 2.7235772357723577e-05, 'epo

  0%|          | 0/65 [00:00<?, ?it/s]

{'eval_loss': 8.02699089050293, 'eval_accuracy': 0.02575315840621963, 'eval_runtime': 25.2001, 'eval_samples_per_second': 163.333, 'eval_steps_per_second': 2.579, 'epoch': 1.0}
{'train_runtime': 236.2359, 'train_samples_per_second': 74.049, 'train_steps_per_second': 1.16, 'train_loss': 8.129460439194728, 'epoch': 1.0}


TrainOutput(global_step=274, training_loss=8.129460439194728, metrics={'train_runtime': 236.2359, 'train_samples_per_second': 74.049, 'train_steps_per_second': 1.16, 'train_loss': 8.129460439194728, 'epoch': 1.0})

### Evaluation

In [7]:
results = trainer.evaluate()
print(results)

test_results = trainer.predict(ds["test"].select(range(100)))

  0%|          | 0/7 [00:00<?, ?it/s]

{'eval_loss': 8.275933265686035, 'eval_accuracy': 0.0, 'eval_runtime': 0.6419, 'eval_samples_per_second': 155.791, 'eval_steps_per_second': 10.905, 'epoch': 0.57}


  0%|          | 0/7 [00:00<?, ?it/s]

In [12]:
model

EfficientNetForImageClassification(
  (efficientnet): EfficientNetModel(
    (embeddings): EfficientNetEmbeddings(
      (padding): ZeroPad2d((0, 1, 0, 1))
      (convolution): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=valid, bias=False)
      (batchnorm): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (activation): SiLU()
    )
    (encoder): EfficientNetEncoder(
      (blocks): ModuleList(
        (0): EfficientNetBlock(
          (depthwise_conv): EfficientNetDepthwiseLayer(
            (depthwise_conv_pad): ZeroPad2d((0, 1, 0, 1))
            (depthwise_conv): EfficientNetDepthwiseConv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=40, bias=False)
            (depthwise_norm): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
            (depthwise_act): SiLU()
          )
          (squeeze_excite): EfficientNetSqueezeExciteLayer(
            (squeeze): AdaptiveAvgPool2d(output