"""Fine-tune a ViT image classifier on the Hugging Face "beans" dataset.

Loads the "beans" dataset and fine-tunes
``google/vit-base-patch16-224-in21k`` on its ``train`` split. It evaluates
on the ``validation`` split after every epoch and on the held-out ``test``
split once training finishes. The fine-tuned model and its image processor
are then saved to ``./computer-vision-beans``.
"""

import inspect

import numpy as np
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, ViTForImageClassification

try:
    from transformers import ViTImageProcessor
except ImportError:  # older transformers releases only ship the feature extractor
    from transformers import ViTFeatureExtractor as ViTImageProcessor

CHECKPOINT = "google/vit-base-patch16-224-in21k"
OUTPUT_DIR = "./results"
SAVE_DIR = "./computer-vision-beans"

# `evaluation_strategy` is deprecated in newer transformers releases in favour
# of `eval_strategy`; use whichever keyword the installed version accepts.
_EVAL_STRATEGY_KW = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)


def compute_metrics(eval_pred):
    """Return classification accuracy for one Trainer evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair handed over by ``Trainer``; the
            class dimension of ``logits`` is assumed to be the last axis.

    Returns:
        ``{"accuracy": float}`` -- fraction of argmax predictions equal to
        the reference labels.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predictions == labels))}


def main():
    """Load, preprocess, train, evaluate on the test split, and save."""
    dataset = load_dataset("beans")

    # Take the class names from the dataset's label feature instead of
    # hard-coding 3, and store the mapping in the model config so the saved
    # model reports readable class names.
    label_names = dataset["train"].features["labels"].names
    id2label = dict(enumerate(label_names))
    label2id = {name: idx for idx, name in id2label.items()}

    model = ViTForImageClassification.from_pretrained(
        CHECKPOINT,
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
    )
    feature_extractor = ViTImageProcessor.from_pretrained(CHECKPOINT)

    def preprocess_function(examples):
        """Convert a batch of images into ``pixel_values`` plus ``labels``."""
        # Force 3 channels so a grayscale/RGBA image cannot break the batch.
        images = [image.convert("RGB") for image in examples["image"]]
        inputs = feature_extractor(images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Drop the original columns (decoded images, file paths) so the cached
    # dataset only keeps what the model consumes.
    dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        num_train_epochs=3,
        weight_decay=0.01,
        **{_EVAL_STRATEGY_KW: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        # Beans has distinct "validation" and "test" splits: validate on the
        # former during training and keep the latter for the final score.
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Final, unbiased score on the held-out test split.
    test_metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
    trainer.log_metrics("test", test_metrics)

    # Save the fine-tuned model together with its image processor.
    model.save_pretrained(SAVE_DIR)
    feature_extractor.save_pretrained(SAVE_DIR)


if __name__ == "__main__":
    main()