{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import pandas as pd\n", "import os\n", "from datasets import Dataset, Image, DatasetDict\n", "from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor\n", "from transformers import (\n", " AutoImageProcessor,\n", " AutoModelForImageClassification,\n", " TrainingArguments,\n", " Trainer,\n", " DefaultDataCollator,\n", ")\n", "import evaluate\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Keep only labels that have at least this many images\n", "MIN_IMAGES_PER_LABEL = 3\n", "\n", "file2obj = pd.read_csv(\"../data/processed/OM_file_to_obj.csv\").rename(columns={\"obj_num\": \"label\"})\n", "# Build each image path relative to this notebook; a list comprehension avoids the slow row-wise\n", "# DataFrame.apply(axis=1), which also returns a DataFrame instead of a Series on an empty frame\n", "file2obj[\"image\"] = [os.path.join(\"..\", root, name) for root, name in zip(file2obj[\"root\"], file2obj[\"file\"])]\n", "\n", "# Count how many images each label has\n", "obj_num_counts = file2obj[\"label\"].value_counts()\n", "\n", "# Keep only the rows whose label occurs at least MIN_IMAGES_PER_LABEL times (i.e. more than twice)\n", "file2obj_3 = file2obj[file2obj[\"label\"].isin(obj_num_counts[obj_num_counts >= MIN_IMAGES_PER_LABEL].index)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Form HF dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "59370086a1b64dc5842d9becd9019aad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Casting to class labels: 0%| | 0/25725 [00:00