{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os; os.chdir('..')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"from datasets import Dataset, load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" category | \n",
" label | \n",
" label_id | \n",
"
\n",
" \n",
" \n",
" \n",
" 3982 | \n",
" Citation context relevance assessment platforms | \n",
" Reference | \n",
" 12 | \n",
"
\n",
" \n",
" 24651 | \n",
" Geology fieldwork | \n",
" Science | \n",
" 2 | \n",
"
\n",
" \n",
" 28113 | \n",
" Password management for individuals | \n",
" Computers_and_Electronics | \n",
" 7 | \n",
"
\n",
" \n",
" 10999 | \n",
" Real estate market statistics | \n",
" Real Estate | \n",
" 24 | \n",
"
\n",
" \n",
" 17096 | \n",
" Running gear for women | \n",
" Beauty_and_Fitness | \n",
" 9 | \n",
"
\n",
" \n",
" 2374 | \n",
" Sports Team Fan Pride | \n",
" Sports | \n",
" 26 | \n",
"
\n",
" \n",
" 9932 | \n",
" Wine and food events | \n",
" Food_and_Drink | \n",
" 15 | \n",
"
\n",
" \n",
" 2953 | \n",
" College admissions for aspiring dancers | \n",
" Jobs_and_Education | \n",
" 21 | \n",
"
\n",
" \n",
" 25038 | \n",
" Software development best practices forums | \n",
" Online Communities | \n",
" 8 | \n",
"
\n",
" \n",
" 29703 | \n",
" Quantum physics theories | \n",
" Science | \n",
" 2 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" category \\\n",
"3982 Citation context relevance assessment platforms \n",
"24651 Geology fieldwork \n",
"28113 Password management for individuals \n",
"10999 Real estate market statistics \n",
"17096 Running gear for women \n",
"2374 Sports Team Fan Pride \n",
"9932 Wine and food events \n",
"2953 College admissions for aspiring dancers \n",
"25038 Software development best practices forums \n",
"29703 Quantum physics theories \n",
"\n",
" label label_id \n",
"3982 Reference 12 \n",
"24651 Science 2 \n",
"28113 Computers_and_Electronics 7 \n",
"10999 Real Estate 24 \n",
"17096 Beauty_and_Fitness 9 \n",
"2374 Sports 26 \n",
"9932 Food_and_Drink 15 \n",
"2953 Jobs_and_Education 21 \n",
"25038 Online Communities 8 \n",
"29703 Science 2 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df= pd.read_csv('data_categories/Final_Category_Data_With_Labels.csv')\n",
"\n",
"\n",
"df.sample(10)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" category | \n",
" label_id | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Internet usage monitoring | \n",
" 25 | \n",
"
\n",
" \n",
" 1 | \n",
" Food safety guidelines and regulations | \n",
" 15 | \n",
"
\n",
" \n",
" 2 | \n",
" Internet protocols and edge computing in finance | \n",
" 25 | \n",
"
\n",
" \n",
" 3 | \n",
" Online grocery shopping | \n",
" 15 | \n",
"
\n",
" \n",
" 4 | \n",
" Writing retreats for poets and novelists | \n",
" 17 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" category label_id\n",
"0 Internet usage monitoring 25\n",
"1 Food safety guidelines and regulations 15\n",
"2 Internet protocols and edge computing in finance 25\n",
"3 Online grocery shopping 15\n",
"4 Writing retreats for poets and novelists 17"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_new= df[['category', 'label_id']]\n",
"df_new.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False 22474\n",
"True 11138\n",
"Name: count, dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_new.duplicated().value_counts() # 10837 duplicate values"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_139501/984288843.py:1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" df_new.rename(\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text | \n",
" label | \n",
"
\n",
" \n",
" \n",
" \n",
" 2925 | \n",
" Kids' toy stores online | \n",
" 13 | \n",
"
\n",
" \n",
" 31108 | \n",
" Birdwatching apps for bird behavior | \n",
" 5 | \n",
"
\n",
" \n",
" 6817 | \n",
" Legal developments | \n",
" 1 | \n",
"
\n",
" \n",
" 20037 | \n",
" Citation context relevance assessment tools | \n",
" 12 | \n",
"
\n",
" \n",
" 18928 | \n",
" Orchid care guide | \n",
" 20 | \n",
"
\n",
" \n",
" 33358 | \n",
" Scientific publications and journals | \n",
" 2 | \n",
"
\n",
" \n",
" 16499 | \n",
" Service dog etiquette | \n",
" 5 | \n",
"
\n",
" \n",
" 26484 | \n",
" Social media trends analysis | \n",
" 25 | \n",
"
\n",
" \n",
" 15543 | \n",
" Troubleshooting computer issues | \n",
" 7 | \n",
"
\n",
" \n",
" 15854 | \n",
" large | \n",
" 23 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text label\n",
"2925 Kids' toy stores online 13\n",
"31108 Birdwatching apps for bird behavior 5\n",
"6817 Legal developments 1\n",
"20037 Citation context relevance assessment tools 12\n",
"18928 Orchid care guide 20\n",
"33358 Scientific publications and journals 2\n",
"16499 Service dog etiquette 5\n",
"26484 Social media trends analysis 25\n",
"15543 Troubleshooting computer issues 7\n",
"15854 large 23"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_new.rename(\n",
" columns={\n",
" \"category\": \"text\", \n",
" \"label_id\": \"label\"\n",
"}, \n",
" inplace=True\n",
")\n",
"\n",
"df_new.sample(10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/pyarrow/pandas_compat.py:373: FutureWarning: is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.\n",
" if _pandas_api.is_sparse(col):\n"
]
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 33612\n",
"})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_df= Dataset.from_pandas(df_new)\n",
"dataset_df"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 26889\n",
" })\n",
" test: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 6723\n",
" })\n",
"})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_data= dataset_df.train_test_split(test_size=0.2)\n",
"new_data"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def preprocess_function(examples):\n",
" return tokenizer(examples[\"text\"], truncation=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 48%|████▊ | 13000/26889 [00:00<00:00, 32226.42 examples/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 26889/26889 [00:00<00:00, 34388.34 examples/s]\n",
"Map: 100%|██████████| 6723/6723 [00:00<00:00, 41978.69 examples/s]\n"
]
}
],
"source": [
"tokenized_df = new_data.map(preprocess_function, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-13 10:29:49.212220: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-10-13 10:29:50.573292: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"source": [
"# from transformers import DataCollatorWithPadding\n",
"\n",
"# data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=\"tf\")\n",
"\n",
"\n",
"\n",
"\n",
"from transformers import DataCollatorWithPadding\n",
"\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import evaluate\n",
"\n",
"accuracy = evaluate.load(\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" predictions = np.argmax(predictions, axis=1)\n",
" return accuracy.compute(predictions=predictions, references=labels)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import json\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Hobbies_and_Leisure': 0,\n",
" 'News': 1,\n",
" 'Science': 2,\n",
" 'Autos_and_Vehicles': 3,\n",
" 'Health': 4,\n",
" 'Pets_and_Animals': 5,\n",
" 'Adult': 6,\n",
" 'Computers_and_Electronics': 7,\n",
" 'Online Communities': 8,\n",
" 'Beauty_and_Fitness': 9,\n",
" 'People_and_Society': 10,\n",
" 'Business_and_Industrial': 11,\n",
" 'Reference': 12,\n",
" 'Shopping': 13,\n",
" 'Travel_and_Transportation': 14,\n",
" 'Food_and_Drink': 15,\n",
" 'Law_and_Government': 16,\n",
" 'Books_and_Literature': 17,\n",
" 'Finance': 18,\n",
" 'Games': 19,\n",
" 'Home_and_Garden': 20,\n",
" 'Jobs_and_Education': 21,\n",
" 'Arts_and_Entertainment': 22,\n",
" 'Sensitive Subjects': 23,\n",
" 'Real Estate': 24,\n",
" 'Internet_and_Telecom': 25,\n",
" 'Sports': 26}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label2id= json.load(\n",
" open('data/categories_refined.json', 'r')\n",
")\n",
"label2id"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{0: 'Hobbies_and_Leisure',\n",
" 1: 'News',\n",
" 2: 'Science',\n",
" 3: 'Autos_and_Vehicles',\n",
" 4: 'Health',\n",
" 5: 'Pets_and_Animals',\n",
" 6: 'Adult',\n",
" 7: 'Computers_and_Electronics',\n",
" 8: 'Online Communities',\n",
" 9: 'Beauty_and_Fitness',\n",
" 10: 'People_and_Society',\n",
" 11: 'Business_and_Industrial',\n",
" 12: 'Reference',\n",
" 13: 'Shopping',\n",
" 14: 'Travel_and_Transportation',\n",
" 15: 'Food_and_Drink',\n",
" 16: 'Law_and_Government',\n",
" 17: 'Books_and_Literature',\n",
" 18: 'Finance',\n",
" 19: 'Games',\n",
" 20: 'Home_and_Garden',\n",
" 21: 'Jobs_and_Education',\n",
" 22: 'Arts_and_Entertainment',\n",
" 23: 'Sensitive Subjects',\n",
" 24: 'Real Estate',\n",
" 25: 'Internet_and_Telecom',\n",
" 26: 'Sports'}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"id2label= {}\n",
"for key in label2id.keys():\n",
" id2label[label2id[key]] = key\n",
" \n",
"id2label"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
" \"finetuned_entity_categorical_classification/checkpoint-3346\", num_labels=27, id2label=id2label, label2id=label2id\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [3362/3362 01:52, Epoch 2/2]\n",
"
\n",
" \n",
" \n",
" \n",
" Epoch | \n",
" Training Loss | \n",
" Validation Loss | \n",
" Accuracy | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 0.102300 | \n",
" 0.077652 | \n",
" 0.975309 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.083400 | \n",
" 0.086291 | \n",
" 0.974714 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=3362, training_loss=0.08880683540376008, metrics={'train_runtime': 113.5357, 'train_samples_per_second': 473.666, 'train_steps_per_second': 29.612, 'total_flos': 213673546900476.0, 'train_loss': 0.08880683540376008, 'epoch': 2.0})"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_args = TrainingArguments(\n",
" output_dir=\"finetuned_entity_categorical_classification\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=16,\n",
" per_device_eval_batch_size=16,\n",
" num_train_epochs=2,\n",
" weight_decay=0.01,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
" # push_to_hub=True,\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_df[\"train\"],\n",
" eval_dataset=tokenized_df[\"test\"],\n",
" tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}