{ "cells": [ { "cell_type": "code", "execution_count": 12, "id": "45a0207c-925b-4680-8e97-a10960ade737", "metadata": { "execution": { "iopub.execute_input": "2025-04-08T10:24:11.581467Z", "iopub.status.busy": "2025-04-08T10:24:11.580094Z", "iopub.status.idle": "2025-04-08T10:24:15.120464Z", "shell.execute_reply": "2025-04-08T10:24:15.119371Z", "shell.execute_reply.started": "2025-04-08T10:24:11.581416Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Defaulting to user installation because normal site-packages is not writeable\n", "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.16)\n", "Requirement already satisfied: six>=1.10 in /usr/lib/python3/dist-packages (from kaggle) (1.16.0)\n", "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2023.7.22)\n", "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n", "Requirement already satisfied: requests in /home/jupyter/.local/lib/python3.10/site-packages (from kaggle) (2.32.3)\n", "Requirement already satisfied: tqdm in /home/jupyter/.local/lib/python3.10/site-packages (from kaggle) (4.67.1)\n", "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.1)\n", "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.26.16)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.0.0)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (2.0.12)\n", "Requirement already 
satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.4)\n", "\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3 -m pip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "# pip install kaggle" ] }, { "cell_type": "code", "execution_count": 1, "id": "ec6b0b5e-0b4a-414b-8ce2-e5d77d22b0e6", "metadata": { "execution": { "iopub.execute_input": "2025-04-10T10:10:38.319843Z", "iopub.status.busy": "2025-04-10T10:10:38.319090Z", "iopub.status.idle": "2025-04-10T10:11:56.639570Z", "shell.execute_reply": "2025-04-10T10:11:56.638676Z", "shell.execute_reply.started": "2025-04-10T10:10:38.319807Z" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jupyter/.local/lib/python3.10/site-packages/transformers/utils/hub.py:105: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/usr/local/lib/python3.10/dist-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n", " warn(\n", "2025-04-10 10:11:09.359041: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-04-10 10:11:13.083242: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1744279874.875999 3402 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1744279875.404529 3402 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "W0000 00:00:1744279879.393200 3402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1744279879.393256 3402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1744279879.393259 3402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1744279879.393262 3402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "2025-04-10 10:11:19.630625: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. 
While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n", " warnings.warn(_BETA_TRANSFORMS_WARNING)\n", "/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. 
You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n", " warnings.warn(_BETA_TRANSFORMS_WARNING)\n" ] } ], "source": [ "from transformers import pipeline\n", "import json\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from transformers import DistilBertTokenizer\n", "from tqdm import tqdm\n", "import re\n", "from datasets import Dataset\n", "from transformers import AutoModelForSequenceClassification\n", "import torch\n", "import numpy as np\n", "from typing import Dict\n", "from transformers import AutoModel\n", "from torch.nn import BCEWithLogitsLoss\n", "from typing import List\n", "from transformers import TrainingArguments, Trainer\n", "from collections import defaultdict\n", "\n", "from transformers import __version__ as transformers_version" ] }, { "cell_type": "code", "execution_count": 2, "id": "053ad8a2-e87c-4bd2-bc57-273d200c97a4", "metadata": { "execution": { "iopub.execute_input": "2025-04-10T10:11:56.641992Z", "iopub.status.busy": "2025-04-10T10:11:56.641095Z", "iopub.status.idle": "2025-04-10T10:11:56.745161Z", "shell.execute_reply": "2025-04-10T10:11:56.744277Z", "shell.execute_reply.started": "2025-04-10T10:11:56.641954Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(device)\n", "model = \"distilbert-base-cased\"" ] }, { "cell_type": "markdown", "id": "f92d06f0-00c2-4784-967b-37904078db59", "metadata": {}, "source": [ "***Скачивание и обработка данных***" ] }, { "cell_type": "code", "execution_count": 3, "id": "28489e2a-e3fc-43db-ab6c-52fd937453b1", "metadata": { "execution": { "iopub.execute_input": "2025-04-10T10:11:56.747131Z", "iopub.status.busy": "2025-04-10T10:11:56.746233Z", "iopub.status.idle": "2025-04-10T10:11:59.650148Z", "shell.execute_reply": "2025-04-10T10:11:59.649268Z", "shell.execute_reply.started": 
"2025-04-10T10:11:56.747093Z" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_3402/637306215.py:8: DeprecationWarning: load_dataset is deprecated and will be removed in future version.\n", " df = kagglehub.load_dataset(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index(['author', 'day', 'id', 'link', 'month', 'summary', 'tag', 'title',\n", " 'year'],\n", " dtype='object')\n" ] } ], "source": [ "import kagglehub\n", "from kagglehub import KaggleDatasetAdapter\n", "\n", "# Set the path to the file you'd like to load\n", "file_path = \"arxivData.json\"\n", "\n", "# Load the latest version\n", "df = kagglehub.load_dataset(\n", " KaggleDatasetAdapter.PANDAS,\n", " \"neelshah18/arxivdataset\",\n", " file_path,\n", ")\n", "\n", "print(df.head)\n", "print(df.columns)" ] }, { "cell_type": "code", "execution_count": 4, "id": "bb24e6b3-09cb-4498-a387-edb57d17e11a", "metadata": { "execution": { "iopub.execute_input": "2025-04-10T10:11:59.652894Z", "iopub.status.busy": "2025-04-10T10:11:59.652256Z", "iopub.status.idle": "2025-04-10T10:11:59.727485Z", "shell.execute_reply": "2025-04-10T10:11:59.726475Z", "shell.execute_reply.started": "2025-04-10T10:11:59.652852Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "155\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tagtopiccategory
0cs.AIArtificial IntelligenceComputer Science
1cs.ARHardware ArchitectureComputer Science
2cs.CCComputational ComplexityComputer Science
3cs.CEComputational Engineering, Finance, and ScienceComputer Science
4cs.CGComputational GeometryComputer Science
\n", "
" ], "text/plain": [ " tag topic category\n", "0 cs.AI Artificial Intelligence Computer Science\n", "1 cs.AR Hardware Architecture Computer Science\n", "2 cs.CC Computational Complexity Computer Science\n", "3 cs.CE Computational Engineering, Finance, and Science Computer Science\n", "4 cs.CG Computational Geometry Computer Science" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n", "print(len(arxiv_topics_df))\n", "arxiv_topics_df.head(5)" ] }, { "cell_type": "code", "execution_count": 5, "id": "ec103ec9-db89-4b69-999a-d986a38a1a51", "metadata": { "execution": { "iopub.execute_input": "2025-04-10T10:11:59.730334Z", "iopub.status.busy": "2025-04-10T10:11:59.728675Z", "iopub.status.idle": "2025-04-10T10:11:59.746729Z", "shell.execute_reply": "2025-04-10T10:11:59.745830Z", "shell.execute_reply.started": "2025-04-10T10:11:59.730286Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['cs.AI' 'cs.AR' 'cs.CC' 'cs.CE' 'cs.CG' 'cs.CL' 'cs.CR' 'cs.CV' 'cs.CY'\n", " 'cs.DB' 'cs.DC' 'cs.DL' 'cs.DM' 'cs.DS' 'cs.ET' 'cs.FL' 'cs.GL' 'cs.GR'\n", " 'cs.GT' 'cs.HC' 'cs.IR' 'cs.IT' 'cs.LG' 'cs.LO' 'cs.MA' 'cs.MM' 'cs.MS'\n", " 'cs.NA' 'cs.NE' 'cs.NI' 'cs.OH' 'cs.OS' 'cs.PF' 'cs.PL' 'cs.RO' 'cs.SC'\n", " 'cs.SD' 'cs.SE' 'cs.SI' 'cs.SY' 'econ.EM' 'econ.GN' 'econ.TH' 'eess.AS'\n", " 'eess.IV' 'eess.SP' 'eess.SY' 'math.AC' 'math.AG' 'math.AP' 'math.AT'\n", " 'math.CA' 'math.CO' 'math.CT' 'math.CV' 'math.DG' 'math.DS' 'math.FA'\n", " 'math.GM' 'math.GN' 'math.GR' 'math.GT' 'math.HO' 'math.IT' 'math.KT'\n", " 'math.LO' 'math.MG' 'math.MP' 'math.NA' 'math.NT' 'math.OA' 'math.OC'\n", " 'math.PR' 'math.QA' 'math.RA' 'math.RT' 'math.SG' 'math.SP' 'math.ST'\n", " 'astro-ph.CO' 'astro-ph.EP' 'astro-ph.GA' 'astro-ph.HE' 'astro-ph.IM'\n", " 'astro-ph.SR' 'cond-mat.dis-nn' 'cond-mat.mes-hall' 'cond-mat.mtrl-sci'\n", " 'cond-mat.other' 'cond-mat.quant-gas' 
import csv

print(arxiv_topics_df['tag'].unique())

# Fixed position of every top-level arXiv category in the label vector.
category_to_ind = {
    'Computer Science': 0,
    'Economics': 1,
    'Electrical Engineering and Systems Science': 2,
    'Mathematics': 3,
    'Physics': 4,
    'Quantitative Biology': 5,
    'Quantitative Finance': 6,
    'Statistics': 7
}

# Inverse lookup: label-vector position -> category name.
ind_to_category = {ind: name for name, ind in category_to_ind.items()}

# arXiv tag (e.g. 'cs.AI') -> top-level category name, taken row by row from
# the topics table (a repeated tag keeps its last category, as before).
term_to_category = dict(zip(arxiv_topics_df['tag'], arxiv_topics_df['category']))

print(arxiv_topics_df['category'].unique())

# Persist the index -> category mapping as a two-column CSV.
with open('ind_to_category.csv', mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['key', 'value'])  # header
    writer.writerows(ind_to_category.items())

# Read the mapping back. csv yields every field as text, so the key has to be
# converted to int; otherwise the copy is keyed by '0', '1', ... and lookups
# with integer class indices fail.
ind_to_category_copy = {}
with open('ind_to_category.csv', mode='r', newline='') as f:
    reader = csv.reader(f)
    next(reader)  # skip header
    for key, value in reader:
        ind_to_category_copy[int(key)] = value

# The round trip must reproduce the in-memory mapping exactly.
assert ind_to_category_copy == ind_to_category

print(ind_to_category_copy)
7, "id": "96fb8526-c464-4f0e-b6df-22c2d373fc65", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:03.512867Z", "iopub.status.busy": "2025-04-09T16:16:03.512371Z", "iopub.status.idle": "2025-04-09T16:16:06.859361Z", "shell.execute_reply": "2025-04-09T16:16:06.858440Z", "shell.execute_reply.started": "2025-04-09T16:16:03.512832Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'inp': 'Dual Recurrent Attention Units for Visual Question Answering@We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. 
In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.', 'probs': array([0.86230617, 0.01579369, 0.01579369, 0.01579369, 0.01579369,\n", " 0.01579369, 0.01579369, 0.0429317 ])}\n" ] } ], "source": [ "import ast\n", "import scipy\n", "\n", "arxiv_data = []\n", "for i in range(len(df)):\n", " cur_elem = {} \n", " cur_elem['inp'] = df['title'][i] + '@' + df['summary'][i]\n", " probs = [0] * len(category_to_ind)\n", " parsed_tags = ast.literal_eval(df['tag'][i])\n", " total_tags = len(parsed_tags)\n", " for j in range(len(parsed_tags)):\n", " term = parsed_tags[j]['term']\n", " if term not in term_to_category:\n", " continue\n", " category = term_to_category[term]\n", " ind = category_to_ind[category]\n", " probs[ind] += 1\n", " probs = scipy.special.softmax(probs)\n", " cur_elem['probs'] = probs\n", " arxiv_data.append(cur_elem)\n", " \n", " \n", "print(arxiv_data[0])" ] }, { "cell_type": "code", "execution_count": 8, "id": "3740a1bb-d0d4-4da4-aa88-e26061bda9a8", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:06.860959Z", "iopub.status.busy": "2025-04-09T16:16:06.860456Z", "iopub.status.idle": "2025-04-09T16:16:06.908278Z", "shell.execute_reply": "2025-04-09T16:16:06.907361Z", "shell.execute_reply.started": "2025-04-09T16:16:06.860924Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}\n" ] } ], "source": [ "import ast\n", "parsed_list = ast.literal_eval(df['tag'][0])\n", "print(parsed_list[0])" ] }, { "cell_type": "code", "execution_count": 9, "id": "bee4d350-7919-4a89-a9fc-bbea804afe0a", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:06.910847Z", "iopub.status.busy": "2025-04-09T16:16:06.909427Z", "iopub.status.idle": "2025-04-09T16:16:06.941781Z", "shell.execute_reply": "2025-04-09T16:16:06.940919Z", 
"shell.execute_reply.started": "2025-04-09T16:16:06.910806Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "41000\n" ] } ], "source": [ "print(len(arxiv_data))" ] }, { "cell_type": "code", "execution_count": 10, "id": "53ed96e7-eec9-4f1a-85fb-f04b5782ed92", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:06.945896Z", "iopub.status.busy": "2025-04-09T16:16:06.945287Z", "iopub.status.idle": "2025-04-09T16:16:08.000266Z", "shell.execute_reply": "2025-04-09T16:16:07.999408Z", "shell.execute_reply.started": "2025-04-09T16:16:06.945856Z" }, "tags": [] }, "outputs": [], "source": [ "from transformers import DistilBertTokenizer, DistilBertModel\n", "\n", "# Load pretrained tokenizer and model\n", "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "f9317c52-636c-4db9-ab02-299951dec41d", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:08.002380Z", "iopub.status.busy": "2025-04-09T16:16:08.001362Z", "iopub.status.idle": "2025-04-09T16:16:08.190707Z", "shell.execute_reply": "2025-04-09T16:16:08.189828Z", "shell.execute_reply.started": "2025-04-09T16:16:08.002341Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 512])\n", "torch.Size([1, 512])\n" ] } ], "source": [ "MAX_LENGTH = 512\n", "tokenized_sentence = tokenizer(arxiv_data[0]['inp'], padding=\"max_length\", truncation=True, max_length=MAX_LENGTH, return_tensors=\"pt\")\n", "print(tokenized_sentence['input_ids'].shape)\n", "print(tokenized_sentence['attention_mask'].shape)\n", "\n", "# max_inp_size = 0\n", "# avg_inp_size = 0\n", "# for i in range(len(arxiv_data)):\n", "# max_inp_size = max(max_inp_size, len(arxiv_data[i]['inp']))\n", "# avg_inp_size += len(arxiv_data[i]['inp'])\n", "\n", "# avg_inp_size /= len(arxiv_data)\n", "\n", "# print(avg_inp_size)\n", "# print(max_inp_size)" ] }, { "cell_type": 
"code", "execution_count": 12, "id": "002fa81d-8775-4ddf-b730-db21b7db89b2", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:08.193634Z", "iopub.status.busy": "2025-04-09T16:16:08.191762Z", "iopub.status.idle": "2025-04-09T16:16:08.211641Z", "shell.execute_reply": "2025-04-09T16:16:08.210798Z", "shell.execute_reply.started": "2025-04-09T16:16:08.193594Z" }, "tags": [] }, "outputs": [], "source": [ "from torch.utils.data import DataLoader, Dataset\n", "\n", "class ArxivDataset(Dataset):\n", " def __init__(self, split=\"train\"):\n", " if (split == \"train\"):\n", " self.data = arxiv_data[:39000]\n", " else:\n", " self.data = arxiv_data[-2000:]\n", " \n", " def __len__(self):\n", " return len(self.data)\n", " \n", " def __getitem__(self, idx):\n", " #обрабатываем элемент\n", " cur_elem = self.data[idx]\n", " \n", " # Tokenize input and label\n", " encoding = tokenizer(cur_elem[\"inp\"], padding=\"max_length\", truncation=True, max_length=MAX_LENGTH, return_tensors=\"pt\")\n", " labels = torch.FloatTensor(cur_elem[\"probs\"])\n", " \n", " # input_ids = encoding[\"input_ids\"].squeeze(0) # Remove batch dim\n", " # attention_mask = encoding[\"attention_mask\"].squeeze(0)\n", " input_ids = encoding[\"input_ids\"].squeeze(0) # Remove batch dim\n", " attention_mask = encoding[\"attention_mask\"].squeeze(0)\n", " \n", " \n", " return {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"labels\": labels}\n", "\n", "BATCH_SIZE = 10\n", "\n", "train_dataset = ArxivDataset(\"train\")\n", "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "\n", "val_dataset = ArxivDataset(\"test\")\n", "val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)" ] }, { "cell_type": "markdown", "id": "04c52585-65ce-4151-856f-b0763460898e", "metadata": {}, "source": [ "***Разберемся с моделью***" ] }, { "cell_type": "code", "execution_count": 13, "id": "014e60b2-1684-4270-a81a-29cac4ea46de", "metadata": { 
"execution": { "iopub.execute_input": "2025-04-09T16:16:08.216322Z", "iopub.status.busy": "2025-04-09T16:16:08.212808Z", "iopub.status.idle": "2025-04-09T16:16:18.518029Z", "shell.execute_reply": "2025-04-09T16:16:18.517118Z", "shell.execute_reply.started": "2025-04-09T16:16:08.216273Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10, 512])\n", "torch.Size([10, 512])\n", "torch.Size([10, 8])\n", "torch.Size([10, 512, 768])\n" ] } ], "source": [ "model = DistilBertModel.from_pretrained('distilbert-base-cased')\n", "model.eval()\n", "\n", "\n", "# output = model(train_dataset[0][\"input_ids\"], train_dataset[0][\"attention_mask\"])\n", "# config = model.config\n", "\n", "batch = next(iter(train_dataloader))\n", "\n", "print(batch['input_ids'].shape)\n", "print(batch['attention_mask'].shape)\n", "print(batch['labels'].shape)\n", "print(model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).last_hidden_state.shape)" ] }, { "cell_type": "code", "execution_count": 14, "id": "cff821d0-806c-4c3b-b6d3-2b99291e65a0", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:18.519553Z", "iopub.status.busy": "2025-04-09T16:16:18.519046Z", "iopub.status.idle": "2025-04-09T16:16:18.539926Z", "shell.execute_reply": "2025-04-09T16:16:18.538850Z", "shell.execute_reply.started": "2025-04-09T16:16:18.519516Z" }, "tags": [] }, "outputs": [], "source": [ "from torch import nn\n", "\n", "class ClassificationModel(nn.Module):\n", " def __init__(self, base_model):\n", " super(ClassificationModel, self).__init__()\n", " self.base_model = base_model\n", " # self.linear = nn.Linear(768, 8)\n", " # self.softmax = nn.Softmax(dim=1)\n", " self.classifier = nn.Sequential(\n", " nn.Linear(768, 256), # Optional intermediate layer\n", " nn.ReLU(),\n", " nn.Dropout(0.3),\n", " nn.Linear(256, 8), # Final layer to 8 classes\n", " nn.LogSoftmax(dim=1)\n", " )\n", "\n", " def forward(self, input_ids, 
attention_mask):\n", " hidden_states = self.base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state # shape: [batch, max_len, hidden_size]\n", " cls_output = hidden_states[:, 0, :] # Use [CLS] token\n", " probs = self.classifier(cls_output) # (batch_size, 8), now probabilities\n", " return probs\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "243c7b28-fbec-4dad-9602-17691cd6e9c6", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:18.541653Z", "iopub.status.busy": "2025-04-09T16:16:18.541121Z", "iopub.status.idle": "2025-04-09T16:16:27.774174Z", "shell.execute_reply": "2025-04-09T16:16:27.773344Z", "shell.execute_reply.started": "2025-04-09T16:16:18.541615Z" }, "tags": [] }, "outputs": [], "source": [ "class_model = ClassificationModel(model)\n", "\n", "# class_model = torch.load(\"model_weights.pth\")\n", "class_model.load_state_dict(torch.load(\"model_weights.pth\", weights_only=True))\n", "\n", "class_model.to(device)\n", "\n", "for param in class_model.base_model.parameters():\n", " param.requires_grad = False\n", "\n", "# print(class_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).shape)" ] }, { "cell_type": "code", "execution_count": 16, "id": "6e656bda-bd95-4d26-a8f4-b83841fd8ca6", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:16:27.775874Z", "iopub.status.busy": "2025-04-09T16:16:27.775265Z", "iopub.status.idle": "2025-04-09T16:16:27.796354Z", "shell.execute_reply": "2025-04-09T16:16:27.795408Z", "shell.execute_reply.started": "2025-04-09T16:16:27.775837Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trainable parameters: 198920\n" ] } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "print(\"Trainable parameters:\", count_parameters(class_model))" ] }, { "cell_type": "markdown", "id": "1b69303c-d240-417f-a9ba-fcebdc7fe8be", 
"metadata": { "execution": { "iopub.execute_input": "2025-04-09T10:15:18.556131Z", "iopub.status.busy": "2025-04-09T10:15:18.554843Z", "iopub.status.idle": "2025-04-09T10:15:18.587934Z", "shell.execute_reply": "2025-04-09T10:15:18.586812Z", "shell.execute_reply.started": "2025-04-09T10:15:18.556083Z" }, "tags": [] }, "source": [ "***Training***" ] }, { "cell_type": "code", "execution_count": 33, "id": "b1139bae-5008-418d-bad2-458f34381d2f", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:29:57.827095Z", "iopub.status.busy": "2025-04-09T16:29:57.825931Z", "iopub.status.idle": "2025-04-09T16:29:57.863867Z", "shell.execute_reply": "2025-04-09T16:29:57.863041Z", "shell.execute_reply.started": "2025-04-09T16:29:57.827049Z" }, "tags": [] }, "outputs": [], "source": [ "from IPython.display import clear_output\n", "import warnings\n", "import time\n", "from datetime import timedelta\n", "import numpy as np\n", "from collections import defaultdict\n", "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "def get_lr(optimizer):\n", " for param_group in optimizer.param_groups:\n", " return param_group['lr']\n", "\n", "def learning_loop(\n", " model,\n", " optimizer,\n", " train_dataloader,\n", " val_dataloader,\n", " criterion,\n", " scheduler=None,\n", " min_lr=None,\n", " epochs=3,\n", " val_every=100,\n", " draw_every=50,\n", " separate_show=False,\n", " model_name=None,\n", " chkp_folder=\"./chkps\",\n", " metric_names=None,\n", "):\n", "\n", " device = next(model.parameters()).device\n", " dtype = next(model.parameters()).dtype\n", "\n", " losses = {'train': [], 'val': [], 'accuracy_val': [], 'lr': []}\n", " # lrs = []\n", " best_val_loss = np.Inf\n", " if metric_names is not None:\n", " metrics = defaultdict(list)\n", " start_time = time.monotonic()\n", " \n", " # [1, 1, 0, 0, 0, 1, 1, 0, 0, 0]\n", " idx = 0\n", "\n", " for epoch in range(1, epochs + 1):\n", " model.train()\n", "\n", " for batch_idx, batch in enumerate(train_dataloader):\n", " 
if idx % 10 == 0:\n", " print(idx)\n", " idx += 1\n", "\n", " input_ids = batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = (batch[\"labels\"]).to(device)\n", "\n", " # attention_mask = attention_mask.to(dtype=dtype)\n", "\n", " optimizer.zero_grad()\n", "\n", " model_probs = model(input_ids, attention_mask)\n", " \n", " # print(torch.sum(model_probs))\n", " # print(torch.sum(labels))\n", "\n", " model_loss = criterion(\n", " model_probs,\n", " labels\n", " )\n", " # print(model_loss.item())\n", " \n", " model_loss.backward()\n", " optimizer.step()\n", " scheduler.step()\n", " \n", " current_lr = optimizer.param_groups[0]['lr']\n", " losses['lr'].append(current_lr)\n", " losses['train'].append(model_loss.item())\n", "\n", " # validation\n", " if idx % val_every == 0:\n", " model.eval()\n", "\n", "\n", " with torch.no_grad():\n", " for idx_val, batch in enumerate(val_dataloader):\n", " if idx_val == 10:\n", " break\n", "\n", " input_ids = batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = (batch[\"labels\"]).to(device)\n", "\n", " # attention_mask = attention_mask.to(dtype=dtype)\n", "\n", " model_probs_val = model(input_ids, attention_mask)\n", "\n", " val_loss = criterion(\n", " model_probs_val,\n", " labels\n", " )\n", " \n", " losses['val'].append(val_loss.item())\n", " \n", "\n", " \n", " torch.cuda.empty_cache()\n", " model.train()\n", "\n", " # plotting\n", " if idx % draw_every == 0:\n", " clear_output(True)\n", " plt.clf()\n", " plt.figure(figsize=(10, 5))\n", "\n", " plt.subplot(1, 2, 1)\n", " plt.plot(losses['train'], label='train_loss')\n", " plt.xlabel('Iter')\n", " plt.ylabel('Loss')\n", " plt.title('Training Losses')\n", " plt.legend()\n", " plt.grid(True)\n", "\n", " plt.subplot(1, 2, 2)\n", " plt.plot(losses['val'], label='val_loss')\n", " plt.xlabel('Iter')\n", " plt.ylabel('Loss')\n", " plt.title('Validation Losses')\n", " 
plt.legend()\n", " plt.grid(True)\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", " return losses" ] }, { "cell_type": "code", "execution_count": 34, "id": "7871cabe-52fc-4169-86f3-1672e73d4690", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:30:00.983579Z", "iopub.status.busy": "2025-04-09T16:30:00.982492Z", "iopub.status.idle": "2025-04-09T16:30:01.016961Z", "shell.execute_reply": "2025-04-09T16:30:01.015926Z", "shell.execute_reply.started": "2025-04-09T16:30:00.983534Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11700\n" ] } ], "source": [ "from torch.optim.lr_scheduler import CosineAnnealingLR\n", "from torch.optim.lr_scheduler import LinearLR\n", "import torch.optim as optim\n", "\n", "\n", "num_epochs = 3\n", "\n", "total_steps = num_epochs * len(train_dataloader) # Total training steps\n", "\n", "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", "\n", "scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0)\n", "\n", "print(total_steps)" ] }, { "cell_type": "code", "execution_count": 35, "id": "b0eab979-8ec7-44e4-bfc2-f6409de09192", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T16:30:03.273412Z", "iopub.status.busy": "2025-04-09T16:30:03.272346Z", "iopub.status.idle": "2025-04-09T17:19:23.714816Z", "shell.execute_reply": "2025-04-09T17:19:23.714027Z", "shell.execute_reply.started": "2025-04-09T16:30:03.273370Z" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "\n", "losses = learning_loop(\n", " class_model,\n", " optimizer,\n", " train_dataloader,\n", " val_dataloader,\n", " # nn.CrossEntropyLoss(),\n", " nn.KLDivLoss(reduction=\"batchmean\"),\n", " scheduler=scheduler,\n", " min_lr=None,\n", " epochs=num_epochs,\n", " val_every=20,\n", " draw_every=50,\n", " separate_show=False,\n", " model_name=None,\n", " chkp_folder=\"./chkps\",\n", " metric_names=None,\n", ")\n" ] }, { "cell_type": "code", "execution_count": 37, "id": "e12128f5-2809-474e-9cf8-2274833a2ae0", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T17:19:32.699426Z", "iopub.status.busy": "2025-04-09T17:19:32.698525Z", "iopub.status.idle": "2025-04-09T17:19:50.895279Z", "shell.execute_reply": "2025-04-09T17:19:50.894467Z", "shell.execute_reply.started": "2025-04-09T17:19:32.699388Z" }, "tags": [] }, "outputs": [], "source": [ "torch.save(class_model.state_dict(), \"model_weights.pth\")\n", "torch.save(class_model.state_dict(), \"pytorch_model.bin\")" ] }, { "cell_type": "code", "execution_count": 38, "id": "5dbcbe7f-46f4-45bd-a645-2a90f1528f53", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T17:42:37.613236Z", "iopub.status.busy": "2025-04-09T17:42:37.612271Z", "iopub.status.idle": "2025-04-09T17:42:37.656166Z", "shell.execute_reply": "2025-04-09T17:42:37.655273Z", "shell.execute_reply.started": "2025-04-09T17:42:37.613193Z" }, "tags": [] }, "outputs": [], "source": [ "def inference(\n", " title,\n", " abstract,\n", " threshold=0.50\n", "):\n", " cur_elem = title + '@' + abstract\n", "\n", " # Tokenize input and label\n", " encoding = tokenizer(cur_elem, padding=\"max_length\", truncation=True, max_length=MAX_LENGTH, return_tensors=\"pt\")\n", " input_ids = encoding[\"input_ids\"].to(device)\n", " attention_mask = encoding[\"attention_mask\"].to(device)\n", " \n", " # input_ids.to(device)\n", " attention_mask.to(device)\n", " 
class_model.to(device)\n", " \n", " res_probs = torch.exp(class_model(input_ids, attention_mask))\n", " \n", " print(res_probs)\n", " \n", " probs = res_probs.squeeze(0) # (8,)\n", " \n", " sorted_probs, sorted_indices = torch.sort(probs, descending=True)\n", "\n", " total = 0.0\n", " selected_indices = []\n", " \n", " for prob, idx in zip(sorted_probs, sorted_indices):\n", " total += prob.item()\n", " selected_indices.append(idx.item())\n", " if total >= threshold:\n", " break\n", "\n", " ans_themes = [ind_to_category[elem] for elem in selected_indices]\n", " return ans_themes" ] }, { "cell_type": "code", "execution_count": 39, "id": "f48d8427-7965-47d2-8431-2eb10d8ac7c5", "metadata": { "execution": { "iopub.execute_input": "2025-04-09T17:42:40.652581Z", "iopub.status.busy": "2025-04-09T17:42:40.651551Z", "iopub.status.idle": "2025-04-09T17:42:40.724918Z", "shell.execute_reply": "2025-04-09T17:42:40.724020Z", "shell.execute_reply.started": "2025-04-09T17:42:40.652541Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.1338, 0.1863, 0.1153, 0.1085, 0.1025, 0.1116, 0.1232, 0.1188]],\n", " device='cuda:0', grad_fn=)\n" ] }, { "data": { "text/plain": [ "['Economics', 'Computer Science', 'Quantitative Finance', 'Statistics']" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inference('Dual Recurrent Attention Units for Visual Question Answering', 'We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. 
We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.')" ] }, { "cell_type": "code", "execution_count": null, "id": "409c347c-c439-40ec-a405-823eb64f6ae4", "metadata": { "execution": { "iopub.status.busy": "2025-04-09T16:16:27.891909Z", "iopub.status.idle": "2025-04-09T16:16:27.892360Z", "shell.execute_reply": "2025-04-09T16:16:27.892167Z", "shell.execute_reply.started": "2025-04-09T16:16:27.892146Z" }, "tags": [] }, "outputs": [], "source": [ "print(device)" ] } ], "metadata": { "kernelspec": { "display_name": "DataSphere Kernel", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }