Spaces:
Sleeping
Sleeping
File size: 65,556 Bytes
0bfbe31 |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import ast\n",
"import pandas as pd\n",
"import kagglehub\n",
"from kagglehub import KaggleDatasetAdapter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Reuse the copy kagglehub has already cached; download only when it is missing,\n",
"# so the notebook also runs on a machine that has never fetched the dataset.\n",
"file_path = os.path.expanduser(\n",
"    \"~/.cache/kagglehub/datasets/neelshah18/arxivdataset/versions/2/arxivData.json\"\n",
")\n",
"if not os.path.exists(file_path):\n",
"    dataset_dir = kagglehub.dataset_download(\"neelshah18/arxivdataset\")\n",
"    file_path = os.path.join(dataset_dir, \"arxivData.json\")\n",
"\n",
"arxiv_df = pd.read_json(file_path)\n",
"arxiv_df = arxiv_df.drop(columns=['author', 'day', 'id', 'link', 'month', 'year'])\n",
"\n",
"# `tag` is parsed from its string form into a list of dicts; explode() then\n",
"# yields one row per (article, tag) pair and only the dict's 'term' is kept.\n",
"arxiv_df['tag'] = arxiv_df['tag'].apply(ast.literal_eval)\n",
"arxiv_df = arxiv_df.explode('tag').reset_index(drop=True)\n",
"arxiv_df['tag'] = arxiv_df['tag'].apply(lambda x: x['term'])\n",
"\n",
"# The model input is the title followed by the abstract.\n",
"arxiv_df['text'] = arxiv_df['title'] + ' ' + arxiv_df['summary']\n",
"arxiv_df = arxiv_df.drop(columns=['title', 'summary'])\n",
"arxiv_df = arxiv_df[['text', 'tag']]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>tag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
" <td>cs.AI</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
" <td>cs.CL</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
" <td>cs.CV</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
" <td>cs.NE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
" <td>stat.ML</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Sequential Short-Text Classification with Recu...</td>\n",
" <td>cs.CL</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Sequential Short-Text Classification with Recu...</td>\n",
" <td>cs.AI</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>Sequential Short-Text Classification with Recu...</td>\n",
" <td>cs.LG</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Sequential Short-Text Classification with Recu...</td>\n",
" <td>cs.NE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Sequential Short-Text Classification with Recu...</td>\n",
" <td>stat.ML</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
" <td>cs.CL</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
" <td>cs.AI</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
" <td>cs.LG</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
" <td>cs.NE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
" <td>stat.ML</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text tag\n",
"0 Dual Recurrent Attention Units for Visual Ques... cs.AI\n",
"1 Dual Recurrent Attention Units for Visual Ques... cs.CL\n",
"2 Dual Recurrent Attention Units for Visual Ques... cs.CV\n",
"3 Dual Recurrent Attention Units for Visual Ques... cs.NE\n",
"4 Dual Recurrent Attention Units for Visual Ques... stat.ML\n",
"5 Sequential Short-Text Classification with Recu... cs.CL\n",
"6 Sequential Short-Text Classification with Recu... cs.AI\n",
"7 Sequential Short-Text Classification with Recu... cs.LG\n",
"8 Sequential Short-Text Classification with Recu... cs.NE\n",
"9 Sequential Short-Text Classification with Recu... stat.ML\n",
"10 Multiresolution Recurrent Neural Networks: An ... cs.CL\n",
"11 Multiresolution Recurrent Neural Networks: An ... cs.AI\n",
"12 Multiresolution Recurrent Neural Networks: An ... cs.LG\n",
"13 Multiresolution Recurrent Neural Networks: An ... cs.NE\n",
"14 Multiresolution Recurrent Neural Networks: An ... stat.ML"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first 15 rows of the exploded frame.\n",
"arxiv_df.head(n=15)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from torch.utils.data import Dataset, DataLoader, random_split\n",
"\n",
"class ArticleDataset(Dataset):\n",
"    \"\"\"Map-style dataset of (input_ids, attention_mask, label) triples.\n",
"\n",
"    Args:\n",
"        data: table with a 'text' column and a 'tag' column.\n",
"        tokenizer: callable invoked as ``tokenizer(text, padding=..., truncation=...,\n",
"            max_length=..., return_tensors=\"pt\")`` whose result exposes\n",
"            'input_ids' and 'attention_mask'.\n",
"        label_encoder: encoder mapping tag values to integer ids. It is fitted on\n",
"            ``data['tag']`` only when it has no ``classes_`` attribute yet; an\n",
"            already fitted encoder is reused unchanged via ``transform``.\n",
"        max_length: length every text is padded or truncated to.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, tokenizer, label_encoder, max_length=256):\n",
"        self.tokenizer = tokenizer\n",
"        self.label_encoder = label_encoder\n",
"        self.max_length = max_length\n",
"        self.texts = data['text'].to_list()\n",
"\n",
"        tags = data['tag'].to_list()\n",
"        # Refitting an already fitted encoder would silently renumber the classes\n",
"        # for every other dataset sharing it, so fit at most once.\n",
"        if hasattr(label_encoder, 'classes_'):\n",
"            encoded = label_encoder.transform(tags)\n",
"        else:\n",
"            encoded = label_encoder.fit_transform(tags)\n",
"        self.labels = torch.as_tensor(encoded, dtype=torch.long)\n",
"        assert len(self.texts) == len(self.labels)\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Tokenize the text at `index` and return it together with its label.\"\"\"\n",
"        encoded_text = self.tokenizer(\n",
"            self.texts[index],\n",
"            padding=\"max_length\",\n",
"            truncation=True,\n",
"            max_length=self.max_length,\n",
"            return_tensors=\"pt\"\n",
"        )\n",
"        # squeeze(0) drops the leading axis of the tokenizer output\n",
"        # (presumably a batch axis of size 1 - confirm against the tokenizer).\n",
"        input_ids = encoded_text['input_ids'].squeeze(0)\n",
"        attention_mask = encoded_text['attention_mask'].squeeze(0)\n",
"        return input_ids, attention_mask, self.labels[index]\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of (text, label) rows.\"\"\"\n",
"        return len(self.labels)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"from transformers import DistilBertTokenizer, DistilBertForSequenceClassification\n",
"\n",
"tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-cased\")\n",
"label_encoder = LabelEncoder()\n",
"\n",
"dataset = ArticleDataset(arxiv_df, tokenizer, label_encoder)\n",
"\n",
"# Seed the split so the same rows land in train/test on every run.\n",
"# NOTE(review): arxiv_df holds one row per (article, tag), so rows sharing the\n",
"# same text can fall on both sides of this split - confirm that is acceptable.\n",
"split_generator = torch.Generator().manual_seed(42)\n",
"train_dataset, test_dataset = random_split(dataset, [0.8, 0.2], generator=split_generator)\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"# Run on the GPU when torch can see one, otherwise stay on the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"\n",
"# One output logit per tag known to the fitted label encoder.\n",
"model = DistilBertForSequenceClassification.from_pretrained(\n",
"    \"distilbert-base-cased\",\n",
"    num_labels=len(label_encoder.classes_),\n",
").to(device)\n",
"optimizer = optim.Adam(params=model.parameters(), lr=2e-5)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"from tqdm import tqdm\n",
"\n",
"# Per-epoch training history, kept at module level.\n",
"train_losses, train_accuracies = [], []\n",
"\n",
"os.makedirs(\"checkpoints\", exist_ok=True)\n",
"\n",
"def train(model, epochs, checkpoint_every=4):\n",
"    \"\"\"Train `model` on the notebook-level `train_loader` for `epochs` epochs.\n",
"\n",
"    Uses the module-level `train_loader`, `optimizer`, `device`, `train_losses`\n",
"    and `train_accuracies`. After each epoch the mean batch loss and the accuracy\n",
"    are appended to the history lists and both curves are redrawn.\n",
"\n",
"    Args:\n",
"        model: called as ``model(input_ids, attention_mask=..., labels=...)``; the\n",
"            result must expose ``.loss`` and ``.logits``.\n",
"        epochs: number of passes over `train_loader`.\n",
"        checkpoint_every: positive interval, in epochs, between state-dict\n",
"            checkpoints. The final epoch is always saved as well.\n",
"\n",
"    Returns:\n",
"        The same `model` object.\n",
"    \"\"\"\n",
"    model.train()\n",
"    # Do not depend on another cell having created the directory.\n",
"    os.makedirs(\"checkpoints\", exist_ok=True)\n",
"\n",
"    for epoch in range(epochs):\n",
"        print(f\"\\nEpoch {epoch+1}/{epochs}\")\n",
"\n",
"        total_loss = 0\n",
"        correct = 0\n",
"        total = 0\n",
"\n",
"        for input_ids, attn_mask, labels in tqdm(train_loader):\n",
"            input_ids = input_ids.to(device)\n",
"            attn_mask = attn_mask.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            outputs = model(input_ids, attention_mask=attn_mask, labels=labels)\n",
"            loss = outputs.loss\n",
"            logits = outputs.logits\n",
"\n",
"            total_loss += loss.item()\n",
"            preds = torch.argmax(logits, dim=1)\n",
"            correct += (preds == labels).sum().item()\n",
"            total += labels.size(0)\n",
"\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"        # Mean loss and accuracy for the epoch.\n",
"        avg_loss = total_loss / len(train_loader)\n",
"        accuracy = correct / total\n",
"\n",
"        train_losses.append(avg_loss)\n",
"        train_accuracies.append(accuracy)\n",
"\n",
"        # Clear first, then print: with wait=True the old output is replaced as\n",
"        # soon as new output arrives, so printing before the clear made the\n",
"        # metrics line vanish the moment the figure was shown.\n",
"        clear_output(True)\n",
"        print(f\"Loss: {avg_loss:.4f} | Accuracy: {accuracy:.4f}\")\n",
"\n",
"        # Loss curve.\n",
"        plt.figure(figsize=(10, 4))\n",
"        plt.subplot(1, 2, 1)\n",
"        plt.plot(train_losses, marker='o')\n",
"        plt.title(\"Training Loss\")\n",
"        plt.xlabel(\"Epoch\")\n",
"        plt.ylabel(\"Loss\")\n",
"\n",
"        # Accuracy curve.\n",
"        plt.subplot(1, 2, 2)\n",
"        plt.plot(train_accuracies, marker='o', color='green')\n",
"        plt.title(\"Training Accuracy\")\n",
"        plt.xlabel(\"Epoch\")\n",
"        plt.ylabel(\"Accuracy\")\n",
"\n",
"        plt.tight_layout()\n",
"        plt.show()\n",
"\n",
"        # Save a checkpoint on the interval and always after the last epoch, so\n",
"        # a run whose length is not a multiple of the interval is not lost.\n",
"        is_last_epoch = epoch + 1 == epochs\n",
"        if (epoch + 1) % checkpoint_every == 0 or is_last_epoch:\n",
"            checkpoint_path = f\"checkpoints/epoch_{epoch+1}.pt\"\n",
"            torch.save(model.state_dict(), checkpoint_path)\n",
"            print(f\"Saved checkpoint: {checkpoint_path}\")\n",
"\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Index -> tag-name mapping taken from the fitted label encoder.\n",
"# JSON object keys are strings, so the indices are written as \"0\", \"1\", ...\n",
"dict_to_save = dict(enumerate(label_encoder.classes_))\n",
"\n",
"with open('checkpoints/labels_info.json', 'w') as labels_file:\n",
"    json.dump(dict_to_save, labels_file)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Обучаемых параметров: 2353909\n"
]
}
],
"source": [
"# Freeze every parameter under model.distilbert.\n",
"model.distilbert.requires_grad_(False)\n",
"\n",
"# Optional: unfreeze only the last transformer layer.\n",
"# model.distilbert.transformer.layer[-1].requires_grad_(True)\n",
"\n",
"# Also make sure the classification head stays trainable.\n",
"model.classifier.requires_grad_(True)\n",
" \n",
"def count_trainable_params(model):\n",
"    \"\"\"Return how many scalar parameters of `model` have requires_grad set.\"\"\"\n",
"    trainable = 0\n",
"    for tensor in model.parameters():\n",
"        if tensor.requires_grad:\n",
"            trainable += tensor.numel()\n",
"    return trainable\n",
"\n",
"# Report the trainable-parameter count (the printed label is Russian for \"Trainable parameters\").\n",
"trainable_param_count = count_trainable_params(model)\n",
"print(\"Обучаемых параметров:\", trainable_param_count)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved checkpoint: checkpoints/epoch_8.pt\n"
]
}
],
"source": [
"# Train for 8 epochs; `train` hands back the same model object it was given.\n",
"model = train(model, epochs=8)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "my_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|