{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "from pathlib import Path\n",
    "import click\n",
    "from dotenv import find_dotenv, load_dotenv\n",
    "\n",
    "from datasets import load_dataset, ClassLabel\n",
    "import numpy as np\n",
    "import wandb\n",
    "import yaml\n",
    "from transformers.trainer_callback import EarlyStoppingCallback\n",
    "from artifact_classification.utils import ConfigLoader\n",
    "from torchvision.transforms import (\n",
    " Compose,\n",
    " Normalize,\n",
    " ToTensor,\n",
    " CenterCrop,\n",
    " Resize,\n",
    ")\n",
    "from transformers import (\n",
    " AutoImageProcessor,\n",
    " AutoModelForImageClassification,\n",
    " TrainingArguments,\n",
    " Trainer,\n",
    " DefaultDataCollator,\n",
    " AutoModelForSequenceClassification,\n",
    " DataCollatorWithPadding,\n",
    " AutoTokenizer,\n",
    ")\n",
    "from sklearn.metrics import top_k_accuracy_score\n",
    "import evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Updating with:\n",
      "{'config': 'om3txt_name', 'dataset': 'james-burton/OrientalMuseum_min3-name-text', 'wandb_proj_name': 'OrientalMuesumText', 'model_base': 'microsoft/deberta-v3-base', 'problem_type': 'text'}\n",
      "\n",
      "\n",
      "{'config': 'om3txt_name', 'fast_dev_run': False, 'do_train': True, 'do_predict': True, 'batch_size': 16, 'model_base': 'microsoft/deberta-v3-base', 'output_root': 'models/', 'num_epochs': 100, 'early_stopping_patience': 5, 'grad_accumulation_steps': 1, 'seed': 42, 'logging_steps': 10, 'lr_scheduler': 'linear', 'warmup_ratio': 0, 'weight_decay': 0, 'device': 'cuda', 'num_workers': 1, 'resume_from_checkpoint': False, 'predict_batch_size': 16, 'save_total_limit': 1, 'lr': 5e-05, 'pytorch2_0': False, 'max_length': 512, 'text_column': 'description', 'fp16': True, 'dataset': 'james-burton/OrientalMuseum_min3-name-text', 'wandb_proj_name': 'OrientalMuesumText', 'problem_type': 'text'}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Name of the config entry to load; entries in train_configs.yaml carry a `config` key\n",
    "config = \"om3txt_name\"\n",
    "\n",
    "# Training args: ConfigLoader receives the per-config file and the defaults file\n",
    "args = ConfigLoader(config, \"../configs/train_configs.yaml\", \"../configs/train_default.yaml\")\n",
    "\n",
    "# NOTE(review): the dataset-loading steps below are commented out and do not run here.\n",
    "# # Load dataset, filter out na inputs and labels and encode labels (as label column can change)\n",
    "# dataset = load_dataset(args.dataset) # , download_mode=\"force_redownload\")\n",
    "# dataset = dataset.filter(lambda example: example[args.label_column] is not None)\n",
    "# if args.problem_type == \"text\":\n",
    "# dataset = dataset.filter(lambda example: example[args.text_column] is not None)\n",
    "# dataset = dataset.rename_column(args.label_column, \"label\")\n",
    "# if not isinstance(dataset[\"train\"].features[\"label\"], ClassLabel):\n",
    "# dataset = dataset.class_encode_column(\"label\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "testing om3_num om3_material om3_name om3txt_material om3txt_name om3-white_num om3-white_material om3-white_name om3-3Dwhite_num om3-3Dwhite_material om3-3Dwhite_name om3-3Dwhite-1frame_num om3-3Dwhite-1frame_material om3-3Dwhite-1frame_name om4_num om4_material om4_name om4txt_material om4txt_name om4-white_num om4-white_material om4-white_name om4-3Dwhite_num om4-3Dwhite_material om4-3Dwhite_name om4-3Dwhite-1frame_num om4-3Dwhite-1frame_material om4-3Dwhite-1frame_name om5_num om5_material om5_name om5txt_material om5txt_name om5-white_num om5-white_material om5-white_name om5-3Dwhite_num om5-3Dwhite_material om5-3Dwhite_name om5-3Dwhite-1frame_num om5-3Dwhite-1frame_material om5-3Dwhite-1frame_name om6_num om6_material om6_name om6txt_material om6txt_name om6-white_num om6-white_material om6-white_name om6-3Dwhite_num om6-3Dwhite_material om6-3Dwhite_name om6-3Dwhite-1frame_num om6-3Dwhite-1frame_material om6-3Dwhite-1frame_name om3-3DwhiteTVT_num om3-3DwhiteTVT_material om3-3DwhiteTVT_name\n"
     ]
    }
   ],
   "source": [
    "with 
open(\"../configs/train_configs.yaml\", \"r\", encoding=\"utf-8\") as file:\n",
    "    configs = list(yaml.safe_load_all(file))\n",
    "\n",
    "# One entry per YAML document in the file; list their `config` names\n",
    "config_names = \" \".join(cfg[\"config\"] for cfg in configs)\n",
    "print(config_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'testing om3_material om3_name om3-white_material om3-white_name om3-3Dwhite_material om3-3Dwhite_name om3-3Dwhite-1frame_material om3-3Dwhite-1frame_name om4_material om4_name om4-white_material om4-white_name om4-3Dwhite_material om4-3Dwhite_name om4-3Dwhite-1frame_material om4-3Dwhite-1frame_name om5_material om5_name om5-white_material om5-white_name om5-3Dwhite_material om5-3Dwhite_name om5-3Dwhite-1frame_material om5-3Dwhite-1frame_name om6_material om6_name om6-white_material om6-white_name om6-3Dwhite_material om6-3Dwhite_name om6-3Dwhite-1frame_material om6-3Dwhite-1frame_name om3-3DwhiteTVT_material om3-3DwhiteTVT_name'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Config names that contain neither \"txt\" nor \"num\"\n",
    "\" \".join(\n",
    "    cfg[\"config\"]\n",
    "    for cfg in configs\n",
    "    if \"txt\" not in cfg[\"config\"] and \"num\" not in cfg[\"config\"]\n",
    ")\n",
    "# \" \".join([cfg[\"config\"] for cfg in configs if \"1frame\" in cfg[\"config\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label name -> integer index (0-220)\n",
    "l2i = {\n",
    " \"Album Painting\": 0,\n",
    " \"Animal Figurine\": 1,\n",
    " \"Animal Mummy\": 2,\n",
    " \"Animal bone\": 3,\n",
    " \"Belt Hook\": 4,\n",
    " \"Blouse\": 5,\n",
    " \"Bolt\": 6,\n",
    " \"Box\": 7,\n",
    " \"Brush Pot\": 8,\n",
    " \"Cap\": 9,\n",
    " \"Case\": 10,\n",
    " \"Clay pipe (smoking)\": 11,\n",
    " \"Cosmetic and Medical Equipment and Implements\": 12,\n",
    " \"Cup And Saucer\": 13,\n",
    " \"DVDs\": 14,\n",
    " \"Dagger\": 15,\n",
    " \"Disc\": 16,\n",
    " \"Domestic Equipment and Utensils\": 17,\n",
    " \"Earring\": 18,\n",
    " \"Finger Ring\": 19,\n",
    " \"Funerary Cone\": 20,\n",
    " \"Funerary goods\": 21,\n",
    " \"Funerary money\": 22,\n",
    " \"Hanging\": 23,\n",
    " \"Heart Scarab\": 24,\n",
    " \"Human Figurine\": 25,\n",
    " \"Inkstick\": 26,\n",
    " \"Kite\": 27,\n",
    " \"Kohl Pot\": 28,\n",
    " \"Letter\": 29,\n",
    " \"Manuscript Page\": 30,\n",
    " \"Mat\": 31,\n",
    " \"Mica Painting\": 32,\n",
    " \"Miniature Painting\": 33,\n",
    " \"Mortar\": 34,\n",
    " \"Mummy Label\": 35,\n",
    " \"Oracle Bone\": 36,\n",
    " \"Ostraka\": 37,\n",
    " \"Palette\": 38,\n",
    " \"Panel\": 39,\n",
    " \"Part\": 40,\n",
    " \"Pendant\": 41,\n",
    " \"Pipe\": 42,\n",
    " \"Pith Painting\": 43,\n",
    " \"Plaque\": 44,\n",
    " \"Plate\": 45,\n",
    " \"Scarab Seal\": 46,\n",
    " \"Scarf\": 47,\n",
    " \"Screen\": 48,\n",
    " \"Seal\": 49,\n",
    " \"Slide\": 50,\n",
    " \"Stand\": 51,\n",
    " \"Thangka\": 52,\n",
    " \"Water Dropper\": 53,\n",
    " \"Water Pot\": 54,\n",
    " \"Woodblock Print\": 55,\n",
    " \"accessories\": 56,\n",
    " \"albums\": 57,\n",
    " \"amulets\": 58,\n",
    " \"animation cels\": 59,\n",
    " \"animation drawings\": 60,\n",
    " \"armor\": 61,\n",
    " \"arrowheads\": 62,\n",
    " \"axes: woodworking tools\": 63,\n",
    " \"badges\": 64,\n",
    " \"bags\": 65,\n",
    " \"bandages\": 66,\n",
    " \"baskets\": 67,\n",
    " \"beads\": 68,\n",
    " \"bells\": 69,\n",
    " \"belts\": 70,\n",
    " \"blades\": 71,\n",
    " \"books\": 72,\n",
    " \"bottles\": 73,\n",
    " \"bowls\": 74,\n",
    " \"boxes\": 75,\n",
    " \"bracelets\": 76,\n",
    " \"brick\": 77,\n",
    " \"brooches\": 78,\n",
    " \"brush washers\": 79,\n",
    " \"buckets\": 80,\n",
    " \"buckles\": 81,\n",
    " \"calligraphy\": 82,\n",
    " \"canopic jars\": 83,\n",
    " \"cards\": 84,\n",
    " \"carvings\": 85,\n",
    " \"chains\": 86,\n",
    " \"chessmen\": 87,\n",
    " \"chopsticks\": 88,\n",
    " \"claypipe\": 89,\n",
    " \"cloth\": 90,\n",
    " \"clothing\": 91,\n",
    " \"coats\": 92,\n",
    " \"coins\": 93,\n",
    " \"collar\": 94,\n",
    " \"compact discs\": 95,\n",
    " \"containers\": 96,\n",
    " \"coverings\": 97,\n",
    " \"covers\": 98,\n",
    " \"cups\": 99,\n",
    " \"deity figurine\": 100,\n",
    " \"diagrams\": 101,\n",
    " \"dishes\": 102,\n",
    " \"dolls\": 103,\n",
    " \"drawings\": 104,\n",
    " \"dresses\": 105,\n",
    " \"drums\": 106,\n",
    " \"earrings\": 107,\n",
    " \"embroidery\": 108,\n",
    " \"ensembles\": 109,\n",
    " \"envelopes\": 110,\n",
    " \"equipment for personal use: grooming, hygiene and health care\": 111,\n",
    " \"ewers\": 112,\n",
    " \"fans\": 113,\n",
    " \"figures\": 114,\n",
    " \"figurines\": 115,\n",
    " \"flags\": 116,\n",
    " \"flasks\": 117,\n",
    " \"furniture components\": 118,\n",
    " \"gaming counters\": 119,\n",
    " \"glassware\": 120,\n",
    " \"hairpins\": 121,\n",
    " \"handles\": 122,\n",
    " \"harnesses\": 123,\n",
    " \"hats\": 124,\n",
    " \"headdresses\": 125,\n",
    " \"heads\": 126,\n",
    " \"incense burners\": 127,\n",
    " \"inlays\": 128,\n",
    " \"jackets\": 129,\n",
    " \"jars\": 130,\n",
    " \"jewelry\": 131,\n",
    " \"juglets\": 132,\n",
    " \"jugs\": 133,\n",
    " \"keys\": 134,\n",
    " \"kimonos\": 135,\n",
    " \"knives\": 136,\n",
    " \"lamps\": 137,\n",
    " \"lanterns\": 138,\n",
    " \"lids\": 139,\n",
    " \"maces\": 140,\n",
    " \"masks\": 141,\n",
    " \"medals\": 142,\n",
    " \"mirrors\": 143,\n",
    " \"models\": 144,\n",
    " \"mounts\": 145,\n",
    " \"nails\": 146,\n",
    " \"necklaces\": 147,\n",
    " \"needles\": 148,\n",
    " \"netsukes\": 149,\n",
    " \"ornaments\": 150,\n",
    " \"pages\": 151,\n",
    " \"paintings\": 152,\n",
    " \"paper money\": 153,\n",
    " \"pendants\": 154,\n",
    " \"petticoats\": 155,\n",
    " \"photographs\": 156,\n",
    " \"pictures\": 157,\n",
    " \"pins\": 158,\n",
    " \"playing cards\": 159,\n",
    " \"poker\": 160,\n",
    " \"postage stamps\": 161,\n",
    " \"postcards\": 162,\n",
    " \"posters\": 163,\n",
    " \"pots\": 164,\n",
    " \"pottery\": 165,\n",
    " \"prints\": 166,\n",
    " \"puppets\": 167,\n",
    " \"purses\": 168,\n",
    " \"reliefs\": 169,\n",
    " \"rings\": 170,\n",
    " \"robes\": 171,\n",
    " \"rubbings\": 172,\n",
    " \"rugs\": 173,\n",
    " \"sandals\": 174,\n",
    " \"saris\": 175,\n",
    " \"sarongs\": 176,\n",
    " \"scabbards\": 177,\n",
    " \"scaraboids\": 178,\n",
    " \"scarabs\": 179,\n",
    " \"scrolls\": 180,\n",
    " \"seed\": 181,\n",
    " \"seppa\": 182,\n",
    " \"shadow puppets\": 183,\n",
    " \"shawls\": 184,\n",
    " \"shell\": 185,\n",
    " \"sherds\": 186,\n",
    " \"shields\": 187,\n",
    " \"shoes\": 188,\n",
    " \"sketches\": 189,\n",
    " \"skirts\": 190,\n",
    " \"snuff bottles\": 191,\n",
    " \"socks\": 192,\n",
    " \"spatulas\": 193,\n",
    " \"spoons\": 194,\n",
    " \"statues\": 195,\n",
    " \"statuettes\": 196,\n",
    " \"stelae\": 197,\n",
    " \"straps\": 198,\n",
    " \"studs\": 199,\n",
    " \"swords\": 200,\n",
    " \"tablets\": 201,\n",
    " \"tacks\": 202,\n",
    " \"tea bowls\": 203,\n",
    " \"teapots\": 204,\n",
    " \"tiles\": 205,\n",
    " \"tools\": 206,\n",
    " \"toys\": 207,\n",
    " \"trays\": 208,\n",
    " \"tubes\": 209,\n",
    " \"tweezers\": 210,\n",
    " \"underwear\": 211,\n",
    " \"unidentified\": 212,\n",
    " \"ushabti\": 213,\n",
    " \"utensils\": 214,\n",
    " \"vases\": 215,\n",
    " \"vessels\": 216,\n",
    " \"weight\": 217,\n",
    " \"weights\": 218,\n",
    " \"whorls\": 219,\n",
    " \"wood blocks\": 220,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invert l2i and write it as JSON: keys are the indices as strings, values are the label names.\n",
    "# NOTE: the file is still named \"l2i.json\" although the mapping stored is index -> label.\n",
    "i2l = {str(idx): label for label, idx in l2i.items()}\n",
    "with open(\"l2i.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(i2l, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "005c080fdcf141acaa30ba191a8c8f3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json: 0%| | 0.00/10.9k [00:00