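# Hugging Face file-page header captured with this script (kept for provenance):
# uploader avatar: ozzyonfire
# commit 03952c2: "using resnet instead of efficientnet"
# page links: raw | history | blame; file size 3.37 kB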
import torch
from datasets import load_dataset
import evaluate
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, TrainingArguments, Trainer
import numpy as np
from torchvision import models, transforms
# Report whether a CUDA device is visible before starting the (long) run.
cuda_available = torch.cuda.is_available()
print("Cuda availability:", cuda_available)
cuda = torch.device('cuda')
if cuda_available:
    # torch.cuda.get_device_name raises when no usable CUDA device exists;
    # querying it unconditionally aborted the script on CPU-only machines
    # right after printing that CUDA is unavailable.
    print("cuda: ", torch.cuda.get_device_name(device=cuda))

# Bird image classification dataset; the script uses its "train" and
# "validation" splits.
dataset = load_dataset("chriamue/bird-species-dataset")
# EfficientNet base checkpoint from the earlier experiment; the ResNet-50
# path in this script does not load it.
model_name = "google/efficientnet-b2"
# Output directory / name for Trainer checkpoints (first TrainingArguments arg).
finetuned_model_name = "chriamue/bird-species-classifier"
# Class-name vocabulary of the dataset, plus lookup tables between class
# names and their stringified integer ids ("0", "1", ...).
labels = dataset["train"].features["label"].names
id2label = {str(idx): name for idx, name in enumerate(labels)}
# Inverse table: class name -> stringified id.
label2id = {name: idx for idx, name in id2label.items()}
# Image preprocessing for the torchvision ResNet-50 backbone; it replaces the
# EfficientNetImageProcessor used by the earlier EfficientNet experiment.
# Steps: resize the shorter side to 256, center-crop 224x224, convert to a
# float tensor, then normalize each channel with the ImageNet mean/std that
# torchvision's pretrained models expect.
preprocessor = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ImageNet-pretrained torchvision ResNet-50 (replacing the earlier
# EfficientNetForImageClassification), with its final fully connected layer
# swapped for a new one sized to the dataset's classes.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    # Weights-enum API: IMAGENET1K_V1 is exactly what the deprecated
    # ``pretrained=True`` flag loads, so the backbone is unchanged.
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    # Older torchvision without the weights API.
    model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
# Freshly initialised classification head: one logit per dataset label.
model.fc = torch.nn.Linear(num_ftrs, len(labels))
# Trainer configuration. Evaluation and checkpointing both happen once per
# epoch, so the checkpoint with the best validation accuracy can be reloaded
# when training finishes (load_best_model_at_end + metric_for_best_model).
training_args = TrainingArguments(
    output_dir=finetuned_model_name,
    # Do not let Trainer drop dataset columns it cannot match to the model's
    # forward signature; "pixel_values" and "label" must reach the collator.
    remove_unused_columns=False,
    # NOTE(review): newer transformers releases rename this argument to
    # ``eval_strategy`` — confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    num_train_epochs=6,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)
# Accuracy metric used by compute_metrics during evaluation.
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Return top-1 accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair handed over by Trainer.
            ``predictions`` holds per-class scores (presumably shaped
            (num_examples, num_classes)), or a tuple whose first element does
            when the model emits additional outputs.

    Returns:
        The dict produced by the module-level ``metric`` (accuracy).
    """
    scores, label_ids = eval_pred
    # Trainer passes a tuple of arrays when the model returns several
    # non-loss outputs; the class scores come first.
    if isinstance(scores, tuple):
        scores = scores[0]
    # Distinct local names so the module-level ``labels`` (class names) is
    # not shadowed inside this function.
    predicted_ids = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=label_ids)
def transforms(examples):
    """Batched ``datasets.map`` callback that adds a ``pixel_values`` column.

    Each entry of ``examples["image"]`` is converted to RGB and passed through
    the module-level torchvision ``preprocessor``, which yields one
    (3, 224, 224) float tensor per image.

    Assumes ``examples["image"]`` holds PIL images (as decoded by the
    ``datasets`` Image feature) — TODO confirm if the dataset changes.

    Note: this function's name shadows the ``torchvision.transforms`` module
    imported at the top of the file. ``preprocessor`` is built earlier in the
    file, so it is not affected.

    Args:
        examples: batch dict from ``Dataset.map(..., batched=True)``.

    Returns:
        The same dict with ``pixel_values`` added.
    """
    examples["pixel_values"] = [
        # A torchvision Compose takes exactly one image argument and returns
        # the tensor itself. It does not accept ``return_tensors=`` and has no
        # ``.pixel_values`` attribute the way an HF image processor does.
        # convert("RGB") guards against grayscale/RGBA images, which the
        # 3-channel Normalize step cannot handle.
        preprocessor(image.convert("RGB"))
        for image in examples["image"]
    ]
    return examples
# Keep one raw training image around for the sanity-check prediction run
# after training.
image = dataset["train"][0]["image"]

# For quick debugging runs, shrink the splits first, e.g.
#   dataset["train"] = dataset["train"].shuffle(seed=42).select(range(1500))
#   dataset["validation"] = dataset["validation"].select(range(100))
#   dataset["test"] = dataset["test"].select(range(100))

# Replace the raw "image" column with preprocessed "pixel_values" tensors.
dataset = dataset.map(
    transforms,
    batched=True,
    remove_columns=["image"],
)
class _TrainerCompatibleClassifier(torch.nn.Module):
    """Adapt a plain torchvision classifier to the interface Trainer expects.

    Trainer calls ``model(**batch)`` with ``pixel_values`` and ``labels``
    keyword arguments and reads ``"loss"`` from the output. A torchvision
    ResNet's ``forward`` takes one positional tensor and returns bare logits.
    Declaring ``labels`` in ``forward`` is also what lets Trainer pass label ids
    through to ``compute_metrics``. The backbone is held by reference, so
    training updates its parameters in place.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, pixel_values, labels=None):
        """Return ``{"logits"}``, plus ``"loss"`` (cross-entropy) when labels are given."""
        logits = self.backbone(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}


trainer = Trainer(
    # The module-level ``model`` stays the bare ResNet, so code after training
    # can keep calling it directly with the fine-tuned weights.
    model=_TrainerCompatibleClassifier(model),
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)
train_results = trainer.train(resume_from_checkpoint=False)
print(trainer.evaluate())
# Save into training_args' output directory.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
# Also write a copy of the final weights into the current working directory.
trainer.save_model(".")
# Export the fine-tuned ResNet-50 to ONNX, then run one sanity-check prediction.
# Example input: a batch of one 3x224x224 image, matching the preprocessor's crop.
dummy_input = torch.randn(1, 3, 224, 224)
model = model.to('cpu')
# Put BatchNorm/Dropout in inference mode explicitly, whatever mode training
# left the model in.
model.eval()
output_onnx_path = 'model.onnx'
torch.onnx.export(model, dummy_input, output_onnx_path, opset_version=13)

# The torchvision preprocessor takes a single image and returns a (3, 224, 224)
# tensor, so add the batch dimension by hand. A torchvision ResNet returns raw
# logits, not an output object with ``.logits``, and takes no keyword inputs.
inputs = preprocessor(image.convert("RGB")).unsqueeze(0)
with torch.no_grad():
    logits = model(inputs)
predicted_label = logits.argmax(-1).item()
print(labels[predicted_label])