{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "feaf77ab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "workding dir: /Users/inflaton/code/engd/papers/maritime/global-incidents\n", "loading env vars from: /Users/inflaton/code/engd/papers/maritime/global-incidents/.env\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import os\n", "import sys\n", "from pathlib import Path\n", "\n", "workding_dir = str(Path.cwd().parent)\n", "os.chdir(workding_dir)\n", "sys.path.append(workding_dir)\n", "print(\"workding dir:\", workding_dir)\n", "\n", "from dotenv import find_dotenv, load_dotenv\n", "\n", "found_dotenv = find_dotenv(\".env\")\n", "\n", "if len(found_dotenv) == 0:\n", " found_dotenv = find_dotenv(\".env.example\")\n", "print(f\"loading env vars from: {found_dotenv}\")\n", "load_dotenv(found_dotenv, override=True)" ] }, { "cell_type": "markdown", "id": "3a7dd7d8", "metadata": {}, "source": [ "## Import Statement" ] }, { "cell_type": "code", "execution_count": 3, "id": "86fc25e6", "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "markdown", "id": "fac53e88", "metadata": {}, "source": [ "### read the data" ] }, { "cell_type": "code", "execution_count": 5, "id": "dc33b13b", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"data/all_port_labelled.csv\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "31f58fd1", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0IndexUnnamed: 0.1HeadlineDetailsSeverityCategoryRegionDatetimeYear...ITEPNEWCSDRPEMNNMif_labeledMonthWeek
00.08.034.0Grasberg Mine- Grasberg mine workers extend st...Media sources indicate that workers at the Gra...ModerateMine Workers StrikeIndonesia28/5/17 17:082017.0...0.00.00.00.00.00.01.0False5.021.0
11.010.063.0Indonesia: Undersea internet cables damaged by...News sources are stating that recent typhoons ...MinorTravel WarningIndonesia4/9/17 14:302017.0...0.00.00.00.00.01.00.0False4.014.0
\n", "

2 rows × 46 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 Index Unnamed: 0.1 \\\n", "0 0.0 8.0 34.0 \n", "1 1.0 10.0 63.0 \n", "\n", " Headline \\\n", "0 Grasberg Mine- Grasberg mine workers extend st... \n", "1 Indonesia: Undersea internet cables damaged by... \n", "\n", " Details Severity \\\n", "0 Media sources indicate that workers at the Gra... Moderate \n", "1 News sources are stating that recent typhoons ... Minor \n", "\n", " Category Region Datetime Year ... IT EP NEW \\\n", "0 Mine Workers Strike Indonesia 28/5/17 17:08 2017.0 ... 0.0 0.0 0.0 \n", "1 Travel Warning Indonesia 4/9/17 14:30 2017.0 ... 0.0 0.0 0.0 \n", "\n", " CSD RPE MN NM if_labeled Month Week \n", "0 0.0 0.0 0.0 1.0 False 5.0 21.0 \n", "1 0.0 0.0 1.0 0.0 False 4.0 14.0 \n", "\n", "[2 rows x 46 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head(2)" ] }, { "cell_type": "markdown", "id": "9bff68c9", "metadata": {}, "source": [ "### Clean empty data" ] }, { "cell_type": "code", "execution_count": 6, "id": "41aa751c", "metadata": {}, "outputs": [], "source": [ "import nltk\n", "from nltk.corpus import stopwords\n", "from nltk.tokenize import word_tokenize\n", "from nltk.stem import WordNetLemmatizer\n", "import string\n", "\n", "# nltk.download('punkt')\n", "# nltk.download('stopwords')\n", "# nltk.download('wordnet')\n", "\n", "\n", "def clean_text(text):\n", " # Lowercase\n", " text = text.lower()\n", " # Tokenization\n", " tokens = word_tokenize(text)\n", " # Removing punctuation\n", " tokens = [word for word in tokens if word not in string.punctuation]\n", " # Removing stop words\n", " stop_words = set(stopwords.words(\"english\"))\n", " tokens = [word for word in tokens if word not in stop_words]\n", " # Lemmatization\n", " lemmatizer = WordNetLemmatizer()\n", " tokens = [lemmatizer.lemmatize(word) for word in tokens]\n", "\n", " return \" \".join(tokens)" ] }, { "cell_type": "code", "execution_count": 7, "id": "6293f613", "metadata": {}, "outputs": [ { 
"name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package omw-1.4 to\n", "[nltk_data] /Users/inflaton/nltk_data...\n", "[nltk_data] Package omw-1.4 is already up-to-date!\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import nltk\n", "\n", "nltk.download(\"omw-1.4\")" ] }, { "cell_type": "markdown", "id": "fad3210d", "metadata": {}, "source": [ "### The Details column has an issue\n", "\n", "some of the data are of the type float and none of the text processing functions can be applied to it therefore we have to process it" ] }, { "cell_type": "code", "execution_count": 8, "id": "b1799269", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 5782 entries, 0 to 5781\n", "Data columns (total 2 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Details 5781 non-null object\n", " 1 Category 5780 non-null object\n", "dtypes: object(2)\n", "memory usage: 90.5+ KB\n", "\n", "RangeIndex: 5782 entries, 0 to 5781\n", "Data columns (total 4 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Details 5781 non-null object\n", " 1 Category 5780 non-null object\n", " 2 Details_cleaned 5781 non-null object\n", " 3 Category_cleaned 5780 non-null object\n", "dtypes: object(4)\n", "memory usage: 180.8+ KB\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/7x/56svhln929zdh2xhr3mwqg4r0000gn/T/ipykernel_15258/1896834377.py:3: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " text_df['Details_cleaned'] = text_df['Details'].apply(lambda x: clean_text(x) if 
not isinstance(x, float) else None)\n", "/var/folders/7x/56svhln929zdh2xhr3mwqg4r0000gn/T/ipykernel_15258/1896834377.py:4: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " text_df['Category_cleaned'] = text_df['Category'].apply(lambda x: None if isinstance(x, float) else x)\n" ] } ], "source": [ "text_df = df[[\"Details\", \"Category\"]]\n", "text_df.info()\n", "text_df[\"Details_cleaned\"] = text_df[\"Details\"].apply(\n", " lambda x: clean_text(x) if not isinstance(x, float) else None\n", ")\n", "text_df[\"Category_cleaned\"] = text_df[\"Category\"].apply(\n", " lambda x: None if isinstance(x, float) else x\n", ")\n", "\n", "# no_nan_df[no_nan_df[\"Details\"].apply(lambda x: print(type(x)))]\n", "# cleaned_df = text_df[text_df[\"Details\"].apply(lambda x: clean_text(x))]\n", "# cleaned_df = df['Details'][1:2]\n", "# type(no_nan_df[\"Details\"][0])\n", "# print(clean_text(no_nan_df[\"Details\"][0]))\n", "text_df.info()" ] }, { "cell_type": "code", "execution_count": 9, "id": "5fcc3b33", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DetailsCategoryDetails_cleanedCategory_cleaned
0Media sources indicate that workers at the Gra...Mine Workers Strikemedium source indicate worker grasberg mine ex...Mine Workers Strike
1News sources are stating that recent typhoons ...Travel Warningnews source stating recent typhoon impact hong...Travel Warning
2The persisting port congestion at Shanghai’s Y...Port Congestionpersisting port congestion shanghai ’ yangshan...Port Congestion
3Updated local media sources from Jakarta indic...Bombing, Police Operationsupdated local medium source jakarta indicate e...Bombing, Police Operations
4According to local police in Jakarta, two expl...Bombing, Police Operationsaccording local police jakarta two explosion c...Bombing, Police Operations
5Severe winds have downed billboards and trees ...Roadway Closure / Disruption, Flooding, Severe...severe wind downed billboard tree bandung wedn...Roadway Closure / Disruption, Flooding, Severe...
6Local media sources indicated on October 29 th...Cargo/Warehouse Theftlocal medium source indicated october 29 wareh...Cargo/Warehouse Theft
7Tropical Storm Rumbia had dissipated after tra...Tropical Cyclone / Stormtropical storm rumbia dissipated travelling ar...Tropical Cyclone / Storm
8Tropical Depression Yutu, also referred to as ...Stormtropical depression yutu also referred `` '' r...Storm
9A magnitude 4.5 earthquake was detected 14 mil...Earthquakemagnitude 4.5 earthquake detected 14 mile nort...Earthquake
\n", "
" ], "text/plain": [ " Details \\\n", "0 Media sources indicate that workers at the Gra... \n", "1 News sources are stating that recent typhoons ... \n", "2 The persisting port congestion at Shanghai’s Y... \n", "3 Updated local media sources from Jakarta indic... \n", "4 According to local police in Jakarta, two expl... \n", "5 Severe winds have downed billboards and trees ... \n", "6 Local media sources indicated on October 29 th... \n", "7 Tropical Storm Rumbia had dissipated after tra... \n", "8 Tropical Depression Yutu, also referred to as ... \n", "9 A magnitude 4.5 earthquake was detected 14 mil... \n", "\n", " Category \\\n", "0 Mine Workers Strike \n", "1 Travel Warning \n", "2 Port Congestion \n", "3 Bombing, Police Operations \n", "4 Bombing, Police Operations \n", "5 Roadway Closure / Disruption, Flooding, Severe... \n", "6 Cargo/Warehouse Theft \n", "7 Tropical Cyclone / Storm \n", "8 Storm \n", "9 Earthquake \n", "\n", " Details_cleaned \\\n", "0 medium source indicate worker grasberg mine ex... \n", "1 news source stating recent typhoon impact hong... \n", "2 persisting port congestion shanghai ’ yangshan... \n", "3 updated local medium source jakarta indicate e... \n", "4 according local police jakarta two explosion c... \n", "5 severe wind downed billboard tree bandung wedn... \n", "6 local medium source indicated october 29 wareh... \n", "7 tropical storm rumbia dissipated travelling ar... \n", "8 tropical depression yutu also referred `` '' r... \n", "9 magnitude 4.5 earthquake detected 14 mile nort... \n", "\n", " Category_cleaned \n", "0 Mine Workers Strike \n", "1 Travel Warning \n", "2 Port Congestion \n", "3 Bombing, Police Operations \n", "4 Bombing, Police Operations \n", "5 Roadway Closure / Disruption, Flooding, Severe... 
\n", "6 Cargo/Warehouse Theft \n", "7 Tropical Cyclone / Storm \n", "8 Storm \n", "9 Earthquake " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processed_data = text_df.dropna()\n", "processed_data.head(10)" ] }, { "cell_type": "code", "execution_count": 10, "id": "d02b4b00", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "857" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processed_data[\"Category\"].nunique()" ] }, { "cell_type": "code", "execution_count": 11, "id": "9ee856a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 5780 entries, 0 to 5781\n", "Data columns (total 4 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Details 5780 non-null object\n", " 1 Category 5780 non-null object\n", " 2 Details_cleaned 5780 non-null object\n", " 3 Category_cleaned 5780 non-null object\n", "dtypes: object(4)\n", "memory usage: 225.8+ KB\n" ] } ], "source": [ "processed_data.info()" ] }, { "cell_type": "markdown", "id": "3f6d478f", "metadata": {}, "source": [ "## Process the Category column\n", "This is seldom done, as we don't usually process the target (y) of the data.\n", "However, the category is too complex and requires processing; otherwise there are simply too many labels." ] }, { "cell_type": "code", "execution_count": 12, "id": "285013d3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "111\n" ] } ], "source": [ "# Create a function that will split the labels into individual\n", "import re\n", "\n", "\n", "def split_string(text):\n", " # Split the string using either \"/\" or \",\" as separator\n", " words = re.split(r\"[\\/,]\", text)\n", " # Remove any leading or trailing whitespace from each word\n", " words = [word.strip() for word in words if word.strip()]\n", " return words\n", "\n", "\n", "# Example usage:\n", "# input_str = \"Roadway 
Closure / Disruption, Flooding, Severe Winds, Weather Advisory\"\n", "# result = split_string(input_str)\n", "# print(result)\n", "\n", "# create a list to find the number of unique individual labels\n", "label_list = []\n", "\n", "for i in processed_data[\"Category_cleaned\"]:\n", " for j in split_string(i):\n", " if j not in label_list:\n", " label_list.append(j)\n", "\n", "# print(label)\n", "print(len(label_list))" ] }, { "cell_type": "markdown", "id": "8e7b48e8", "metadata": {}, "source": [ "#### After filtering out the unique labels in the Category column we are still left with 111 labels, which is still too many" ] }, { "cell_type": "markdown", "id": "33234f8c", "metadata": {}, "source": [ "#### The next step is to reduce each record's category to a single label \n", "Previously a category looked like Roadway Closure / Disruption, Flooding, Severe...; we need to reduce it to a single label \n", "To do this, we assume the first label listed is the most prominent category and discard the other categories" ] }, { "cell_type": "code", "execution_count": 14, "id": "12f9b9b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 5780 entries, 0 to 5781\n", "Data columns (total 5 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Details 5780 non-null object\n", " 1 Category 5780 non-null object\n", " 2 Details_cleaned 5780 non-null object\n", " 3 Category_cleaned 5780 non-null object\n", " 4 Category_single 5780 non-null object\n", "dtypes: object(5)\n", "memory usage: 270.9+ KB\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/7x/56svhln929zdh2xhr3mwqg4r0000gn/T/ipykernel_15258/2344116627.py:25: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the 
documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " text_df['Category_single'] = text_df['Category_cleaned'].apply(lambda x: split_and_get_first(x))\n" ] } ], "source": [ "def split_and_get_first(text):\n", " # Split the string using either \"/\" or \",\" as separator\n", " if text == None:\n", " return None\n", " words = re.split(r\"[\\/,]\", text)\n", " # Remove any leading or trailing whitespace from each word\n", " words = [word.strip() for word in words if word.strip()]\n", " # Return the first word after split\n", " if words:\n", " return words[0]\n", " else:\n", " return None\n", "\n", "\n", "def remove_none_rows(df, column_name):\n", " # Iterate through the DataFrame\n", " for index, value in enumerate(df[column_name]):\n", " # Check if the value is None\n", " if value is None:\n", " # Remove the row where the data belongs to\n", " df = df.drop(index, axis=0)\n", " return df\n", "\n", "\n", "# Example usage:\n", "# input_str = \"Roadway Closure / Disruption, Flooding, Severe Winds, Weather Advisory\"\n", "# result = split_and_get_first(input_str)\n", "# print(result)\n", "text_df[\"Category_single\"] = text_df[\"Category_cleaned\"].apply(\n", " lambda x: split_and_get_first(x)\n", ")\n", "result_df = remove_none_rows(text_df, \"Category_cleaned\")\n", "result_df.info()" ] }, { "cell_type": "code", "execution_count": 15, "id": "b5931fe1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DetailsCategoryDetails_cleanedCategory_cleanedCategory_single
0Media sources indicate that workers at the Gra...Mine Workers Strikemedium source indicate worker grasberg mine ex...Mine Workers StrikeMine Workers Strike
1News sources are stating that recent typhoons ...Travel Warningnews source stating recent typhoon impact hong...Travel WarningTravel Warning
2The persisting port congestion at Shanghai’s Y...Port Congestionpersisting port congestion shanghai ’ yangshan...Port CongestionPort Congestion
3Updated local media sources from Jakarta indic...Bombing, Police Operationsupdated local medium source jakarta indicate e...Bombing, Police OperationsBombing
4According to local police in Jakarta, two expl...Bombing, Police Operationsaccording local police jakarta two explosion c...Bombing, Police OperationsBombing
5Severe winds have downed billboards and trees ...Roadway Closure / Disruption, Flooding, Severe...severe wind downed billboard tree bandung wedn...Roadway Closure / Disruption, Flooding, Severe...Roadway Closure
6Local media sources indicated on October 29 th...Cargo/Warehouse Theftlocal medium source indicated october 29 wareh...Cargo/Warehouse TheftCargo
7Tropical Storm Rumbia had dissipated after tra...Tropical Cyclone / Stormtropical storm rumbia dissipated travelling ar...Tropical Cyclone / StormTropical Cyclone
8Tropical Depression Yutu, also referred to as ...Stormtropical depression yutu also referred `` '' r...StormStorm
9A magnitude 4.5 earthquake was detected 14 mil...Earthquakemagnitude 4.5 earthquake detected 14 mile nort...EarthquakeEarthquake
10Multiple sources report that a magnitude 5.5 e...Earthquakemultiple source report magnitude 5.5 earthquak...EarthquakeEarthquake
11Post-Tropical Cyclone Michael is approximately...Tropical Cyclone / Stormpost-tropical cyclone michael approximately 18...Tropical Cyclone / StormTropical Cyclone
12Industry sources indicate on September 11 that...Workplace Accidentindustry source indicate september 11 2 worker...Workplace AccidentWorkplace Accident
13Government sources are reporting a tornado has...Tornadogovernment source reporting tornado touched tw...TornadoTornado
14Media sources are informing on September 24 th...Industrial Actionmedium source informing september 24 oil worke...Industrial ActionIndustrial Action
15A magnitude 4.5 earthquake was detected in cen...Earthquakemagnitude 4.5 earthquake detected central taiw...EarthquakeEarthquake
16Industry sources indicate on August 31 that th...Port Congestionindustry source indicate august 31 port durban...Port CongestionPort Congestion
17Tropical Depression Gordon continues to weaken...Stormtropical depression gordon continues weaken mo...StormStorm
18Local media sources indicated on November 8 th...Public Safety / Securitylocal medium source indicated november 8 270 k...Public Safety / SecurityPublic Safety
19The European-Mediterranean Seismological Centr...Earthquakeeuropean-mediterranean seismological centre re...EarthquakeEarthquake
\n", "
" ], "text/plain": [ " Details \\\n", "0 Media sources indicate that workers at the Gra... \n", "1 News sources are stating that recent typhoons ... \n", "2 The persisting port congestion at Shanghai’s Y... \n", "3 Updated local media sources from Jakarta indic... \n", "4 According to local police in Jakarta, two expl... \n", "5 Severe winds have downed billboards and trees ... \n", "6 Local media sources indicated on October 29 th... \n", "7 Tropical Storm Rumbia had dissipated after tra... \n", "8 Tropical Depression Yutu, also referred to as ... \n", "9 A magnitude 4.5 earthquake was detected 14 mil... \n", "10 Multiple sources report that a magnitude 5.5 e... \n", "11 Post-Tropical Cyclone Michael is approximately... \n", "12 Industry sources indicate on September 11 that... \n", "13 Government sources are reporting a tornado has... \n", "14 Media sources are informing on September 24 th... \n", "15 A magnitude 4.5 earthquake was detected in cen... \n", "16 Industry sources indicate on August 31 that th... \n", "17 Tropical Depression Gordon continues to weaken... \n", "18 Local media sources indicated on November 8 th... \n", "19 The European-Mediterranean Seismological Centr... \n", "\n", " Category \\\n", "0 Mine Workers Strike \n", "1 Travel Warning \n", "2 Port Congestion \n", "3 Bombing, Police Operations \n", "4 Bombing, Police Operations \n", "5 Roadway Closure / Disruption, Flooding, Severe... \n", "6 Cargo/Warehouse Theft \n", "7 Tropical Cyclone / Storm \n", "8 Storm \n", "9 Earthquake \n", "10 Earthquake \n", "11 Tropical Cyclone / Storm \n", "12 Workplace Accident \n", "13 Tornado \n", "14 Industrial Action \n", "15 Earthquake \n", "16 Port Congestion \n", "17 Storm \n", "18 Public Safety / Security \n", "19 Earthquake \n", "\n", " Details_cleaned \\\n", "0 medium source indicate worker grasberg mine ex... \n", "1 news source stating recent typhoon impact hong... \n", "2 persisting port congestion shanghai ’ yangshan... 
\n", "3 updated local medium source jakarta indicate e... \n", "4 according local police jakarta two explosion c... \n", "5 severe wind downed billboard tree bandung wedn... \n", "6 local medium source indicated october 29 wareh... \n", "7 tropical storm rumbia dissipated travelling ar... \n", "8 tropical depression yutu also referred `` '' r... \n", "9 magnitude 4.5 earthquake detected 14 mile nort... \n", "10 multiple source report magnitude 5.5 earthquak... \n", "11 post-tropical cyclone michael approximately 18... \n", "12 industry source indicate september 11 2 worker... \n", "13 government source reporting tornado touched tw... \n", "14 medium source informing september 24 oil worke... \n", "15 magnitude 4.5 earthquake detected central taiw... \n", "16 industry source indicate august 31 port durban... \n", "17 tropical depression gordon continues weaken mo... \n", "18 local medium source indicated november 8 270 k... \n", "19 european-mediterranean seismological centre re... \n", "\n", " Category_cleaned Category_single \n", "0 Mine Workers Strike Mine Workers Strike \n", "1 Travel Warning Travel Warning \n", "2 Port Congestion Port Congestion \n", "3 Bombing, Police Operations Bombing \n", "4 Bombing, Police Operations Bombing \n", "5 Roadway Closure / Disruption, Flooding, Severe... 
Roadway Closure \n", "6 Cargo/Warehouse Theft Cargo \n", "7 Tropical Cyclone / Storm Tropical Cyclone \n", "8 Storm Storm \n", "9 Earthquake Earthquake \n", "10 Earthquake Earthquake \n", "11 Tropical Cyclone / Storm Tropical Cyclone \n", "12 Workplace Accident Workplace Accident \n", "13 Tornado Tornado \n", "14 Industrial Action Industrial Action \n", "15 Earthquake Earthquake \n", "16 Port Congestion Port Congestion \n", "17 Storm Storm \n", "18 Public Safety / Security Public Safety \n", "19 Earthquake Earthquake " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_df.head(20)" ] }, { "cell_type": "code", "execution_count": 16, "id": "9c19b11a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "94" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_df[\"Category_single\"].nunique()" ] }, { "cell_type": "markdown", "id": "29d4037f", "metadata": {}, "source": [ "### After taking out the first label in the Category column we are still left with 94 unique labels\n", "This is still an unacceptable number of labels. The next step is to manually group the labels into more general labels using a rule-based system." ] }, { "cell_type": "code", "execution_count": 17, "id": "10f07d05", "metadata": {}, "outputs": [], "source": [ "### first export the unique labels into excel for better visualization\n", "unique_labels_df = pd.DataFrame({\"String\": label_list})\n", "file_path = \"data/label_list.xlsx\"\n", "\n", "# Save DataFrame to Excel\n", "unique_labels_df.to_excel(file_path, index=False)" ] }, { "attachments": { "converstion.png": { "image/png": "" } }, "cell_type": "markdown", "id": "398e6da8", "metadata": {}, "source": [ "![converstion.png](attachment:converstion.png)" ] }, { "cell_type": "code", "execution_count": 18, "id": "d4357af0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
WeatherWorker StrikeAdministrative IssueHuman ErrorCyber AttackTerrorismAccidentOthers
0FloodingMine Workers StrikePort CongestionWorkplace AccidentNetwork DisruptionBombingMaritime AccidentMiscellaneous Events
1Severe WindsProduction HaltPolice OperationsIndividuals in FocusRansomwareWarehouse TheftVehicle AccidentMiscellaneous Strikes
2Weather AdvisoryProtestRoadway ClosureMilitary OperationsData breachPublic SafetyDeathOutbreak of disease
3Tropical CycloneRiotDisruptionFlight DelaysPhishingSecurityInjuryNaN
4StormPort StrikeCargoCancellationsNaNOrganized CrimeNon-industrial FireNaN
\n", "
" ], "text/plain": [ " Weather Worker Strike Administrative Issue \\\n", "0 Flooding Mine Workers Strike Port Congestion \n", "1 Severe Winds Production Halt Police Operations \n", "2 Weather Advisory Protest Roadway Closure \n", "3 Tropical Cyclone Riot Disruption \n", "4 Storm Port Strike Cargo \n", "\n", " Human Error Cyber Attack Terrorism \\\n", "0 Workplace Accident Network Disruption Bombing \n", "1 Individuals in Focus Ransomware Warehouse Theft \n", "2 Military Operations Data breach Public Safety \n", "3 Flight Delays Phishing Security \n", "4 Cancellations NaN Organized Crime \n", "\n", " Accident Others \n", "0 Maritime Accident Miscellaneous Events \n", "1 Vehicle Accident Miscellaneous Strikes \n", "2 Death Outbreak of disease \n", "3 Injury NaN \n", "4 Non-industrial Fire NaN " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_labels_df = pd.read_excel(\"data/new_labels.xlsx\")\n", "new_labels_df.head()" ] }, { "cell_type": "markdown", "id": "407189c9", "metadata": {}, "source": [ "#### convert them into lists" ] }, { "cell_type": "code", "execution_count": 19, "id": "73939327", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Weather\n", "\n", "['Flooding', 'Severe Winds', 'Weather Advisory', 'Tropical Cyclone', 'Storm', 'Earthquake', 'Tornado', 'Typhoon', 'Landslide', 'Water', 'Hurricane', 'Wildfire', 'Blizzard', 'Hail']\n", "\n", "\n", "Worker Strike\n", "\n", "['Mine Workers Strike', 'Production Halt', 'Protest', 'Riot', 'Port Strike', 'General Strike', 'Civil Service Strike', 'Civil Unrest Advisory', 'Cargo Transportation Strike', 'Energy Sector Strike']\n", "\n", "\n", "Administrative Issue\n", "\n", "['Port Congestion', 'Police Operations', 'Roadway Closure', 'Disruption', 'Cargo', 'Industrial Action', 'Port Disruption', 'Cargo Disruption', 'Power Outage', 'Port Closure', 'Maritime Advisory', 'Train Delays', 'Ground Transportation Advisory', 'Public 
Transportation Disruption', 'Trade Regulation', 'Customs Regulation', 'Regulatory Advisory', 'Industry Directives', 'Security Advisory', 'Public Holidays', 'Customs Delay', 'Public Health Advisory', 'Detention', 'Aviation Advisory', 'Waterway closure', 'Waterway Closure', 'Plant Closure', 'Border Closure', 'Delay', 'Industrial zone shutdown', 'Trade Restrictions', 'Closure', 'Truck Driving Ban', 'Insolvency', 'Environmental Regulations', 'Postal Disruption', 'Ice Storm', 'Travel Warning']\n", "\n", "\n", "Human Error\n", "\n", "['Workplace Accident', 'Individuals in Focus', 'Military Operations', 'Flight Delays', 'Cancellations', 'Political Info', 'Event']\n", "\n", "\n", "Cyber Attack\n", "\n", "['Network Disruption', 'Ransomware', 'Data breach', 'Phishing']\n", "\n", "\n", "Terrorism\n", "\n", "['Bombing', 'Warehouse Theft', 'Public Safety', 'Security', 'Organized Crime', 'Hazmat Response', 'Piracy', 'Kidnap', 'Shooting', 'Robbery', 'Cargo theft', 'Bomb Detonation', 'Terror Attack', 'Outbreak Of War', 'Militant Action']\n", "\n", "\n", "Accident\n", "\n", "['Maritime Accident', 'Vehicle Accident', 'Death', 'Injury', 'Non-industrial Fire', 'Chemical Spill', 'Industrial Fire', 'Fuel Disruption', 'Airline Incident', 'Crash', 'Explosion', 'Train Accident', 'Derailment', 'Sewage Disruption', 'Barge Accident', 'Bridge Collapse', 'Structure Collapse', 'Airport Accident', 'Force Majeure', 'Telecom Outage']\n", "\n", "\n", "Others\n", "\n", "['Miscellaneous Events', 'Miscellaneous Strikes', 'Outbreak of disease']\n" ] } ], "source": [ "new_labels_dict = new_labels_df.to_dict(orient=\"list\")\n", "\n", "\n", "for key, value in new_labels_dict.items():\n", " new_labels_dict[key] = [item for item in value if not pd.isnull(item)]\n", "\n", "for category in new_labels_dict:\n", " print(\"\\n\")\n", " print(category + \"\\n\")\n", " print(new_labels_dict[category])" ] }, { "cell_type": "markdown", "id": "8516af0e", "metadata": {}, "source": [ "### create a new column with the 
summarized label" ] }, { "cell_type": "code", "execution_count": 20, "id": "0d316bb4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DetailsCategoryDetails_cleanedCategory_cleanedCategory_singleSummarized_label
0Media sources indicate that workers at the Gra...Mine Workers Strikemedium source indicate worker grasberg mine ex...Mine Workers StrikeMine Workers StrikeWorker Strike
1News sources are stating that recent typhoons ...Travel Warningnews source stating recent typhoon impact hong...Travel WarningTravel WarningAdministrative Issue
2The persisting port congestion at Shanghai’s Y...Port Congestionpersisting port congestion shanghai ’ yangshan...Port CongestionPort CongestionAdministrative Issue
3Updated local media sources from Jakarta indic...Bombing, Police Operationsupdated local medium source jakarta indicate e...Bombing, Police OperationsBombingTerrorism
4According to local police in Jakarta, two expl...Bombing, Police Operationsaccording local police jakarta two explosion c...Bombing, Police OperationsBombingTerrorism
.....................
5777Intelligence received by Everstream Analytics ...Ice Stormintelligence received everstream analytics ind...Ice StormIce StormAdministrative Issue
5778Meteorological sources indicate that a series ...Roadway Closure / Disruption, Ground Transport...meteorological source indicate series winter s...Roadway Closure / Disruption, Ground Transport...Roadway ClosureAdministrative Issue
5779Industry sources report on December 7 that Svi...Industrial Actionindustry source report december 7 svitzer aust...Industrial ActionIndustrial ActionAdministrative Issue
5780Industry sources indicate on December 14 that ...Port Strikeindustry source indicate december 14 worker dp...Port StrikePort StrikeWorker Strike
5781On November 17, Dutch media sources reported t...Port Strikenovember 17 dutch medium source reported worke...Port StrikePort StrikeWorker Strike
\n", "

5780 rows × 6 columns

\n", "
" ], "text/plain": [ " Details \\\n", "0 Media sources indicate that workers at the Gra... \n", "1 News sources are stating that recent typhoons ... \n", "2 The persisting port congestion at Shanghai’s Y... \n", "3 Updated local media sources from Jakarta indic... \n", "4 According to local police in Jakarta, two expl... \n", "... ... \n", "5777 Intelligence received by Everstream Analytics ... \n", "5778 Meteorological sources indicate that a series ... \n", "5779 Industry sources report on December 7 that Svi... \n", "5780 Industry sources indicate on December 14 that ... \n", "5781 On November 17, Dutch media sources reported t... \n", "\n", " Category \\\n", "0 Mine Workers Strike \n", "1 Travel Warning \n", "2 Port Congestion \n", "3 Bombing, Police Operations \n", "4 Bombing, Police Operations \n", "... ... \n", "5777 Ice Storm \n", "5778 Roadway Closure / Disruption, Ground Transport... \n", "5779 Industrial Action \n", "5780 Port Strike \n", "5781 Port Strike \n", "\n", " Details_cleaned \\\n", "0 medium source indicate worker grasberg mine ex... \n", "1 news source stating recent typhoon impact hong... \n", "2 persisting port congestion shanghai ’ yangshan... \n", "3 updated local medium source jakarta indicate e... \n", "4 according local police jakarta two explosion c... \n", "... ... \n", "5777 intelligence received everstream analytics ind... \n", "5778 meteorological source indicate series winter s... \n", "5779 industry source report december 7 svitzer aust... \n", "5780 industry source indicate december 14 worker dp... \n", "5781 november 17 dutch medium source reported worke... \n", "\n", " Category_cleaned Category_single \\\n", "0 Mine Workers Strike Mine Workers Strike \n", "1 Travel Warning Travel Warning \n", "2 Port Congestion Port Congestion \n", "3 Bombing, Police Operations Bombing \n", "4 Bombing, Police Operations Bombing \n", "... ... ... \n", "5777 Ice Storm Ice Storm \n", "5778 Roadway Closure / Disruption, Ground Transport... 
Roadway Closure \n", "5779 Industrial Action Industrial Action \n", "5780 Port Strike Port Strike \n", "5781 Port Strike Port Strike \n", "\n", " Summarized_label \n", "0 Worker Strike \n", "1 Administrative Issue \n", "2 Administrative Issue \n", "3 Terrorism \n", "4 Terrorism \n", "... ... \n", "5777 Administrative Issue \n", "5778 Administrative Issue \n", "5779 Administrative Issue \n", "5780 Worker Strike \n", "5781 Worker Strike \n", "\n", "[5780 rows x 6 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_df[\"Summarized_label\"] = None\n", "\n", "for index, row in result_df.iterrows():\n", " value = row[\"Category_single\"]\n", " for key, values in new_labels_dict.items():\n", " if value in values:\n", " result_df.at[index, \"Summarized_label\"] = key\n", " break # No need to check other keys if match found\n", "result_df" ] }, { "cell_type": "markdown", "id": "607a0996", "metadata": {}, "source": [ "## Naive Bayes Model" ] }, { "cell_type": "code", "execution_count": 21, "id": "b8c331bd", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "\n", "# from sklearn.feature_extraction.text import CountVectorizer\n", "from sklearn.naive_bayes import MultinomialNB\n", "from sklearn.metrics import accuracy_score, classification_report" ] }, { "cell_type": "code", "execution_count": 22, "id": "ca8d53af", "metadata": {}, "outputs": [], "source": [ "X = result_df[\"Details_cleaned\"]\n", "y = result_df[\"Summarized_label\"]" ] }, { "cell_type": "code", "execution_count": 23, "id": "432e793e", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "id": "119b6c46", "metadata": {}, "outputs": [], "source": [ "# vectorizer = CountVectorizer()\n", "# 
X_train_vec = vectorizer.fit_transform(X_train)\n", "# X_test_vec = vectorizer.transform(X_test)\n", "\n", "tfidf_vectorizer = TfidfVectorizer(max_features=1000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 25, "id": "18cf6e8e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
MultinomialNB()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "MultinomialNB()" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "naive_bayes = MultinomialNB()\n", "naive_bayes.fit(X_train_tfidf, y_train)" ] }, { "cell_type": "code", "execution_count": 26, "id": "4e4d6e2e", "metadata": {}, "outputs": [], "source": [ "predictions = naive_bayes.predict(X_test_tfidf)" ] }, { "cell_type": "code", "execution_count": 27, "id": "abd1d4a6", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of Naive Bayes model: 0.763840830449827\n", " precision recall f1-score support\n", "\n", " Accident 0.71 0.74 0.72 129\n", "Administrative Issue 0.83 0.89 0.86 662\n", " Cyber Attack 0.00 0.00 0.00 4\n", " Human Error 0.00 0.00 0.00 18\n", " Others 0.41 0.24 0.30 79\n", " Terrorism 0.42 0.15 0.23 52\n", " Weather 0.77 0.92 0.84 92\n", " Worker Strike 0.61 0.69 0.65 120\n", "\n", " accuracy 0.76 1156\n", " macro avg 0.47 0.46 0.45 1156\n", " weighted avg 0.73 0.76 0.74 1156\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. 
Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] } ], "source": [ "accuracy = accuracy_score(y_test, predictions)\n", "print(\"Accuracy of Naive Bayes model:\", accuracy)\n", "print(classification_report(y_test, predictions))" ] }, { "cell_type": "markdown", "id": "0bb9d98b", "metadata": {}, "source": [ "Find the optimal Alpha parameter" ] }, { "cell_type": "code", "execution_count": 28, "id": "f4eead05", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best Alpha: 0.1\n" ] }, { "data": { "text/html": [ "
MultinomialNB(alpha=0.1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "MultinomialNB(alpha=0.1)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "param_grid = {\"alpha\": [0.1, 0.5, 1.0, 2.0]}\n", "\n", "# Initialize the grid search\n", "grid_search = GridSearchCV(MultinomialNB(), param_grid, cv=5, scoring=\"accuracy\")\n", "\n", "# Perform the grid search\n", "grid_search.fit(X_train_tfidf, y_train)\n", "\n", "# Get the best hyperparameters\n", "best_alpha = grid_search.best_params_[\"alpha\"]\n", "print(\"Best Alpha:\", best_alpha)\n", "\n", "# Train the model with the best alpha\n", "naive_bayes_tuned = MultinomialNB(alpha=best_alpha)\n", "naive_bayes_tuned.fit(X_train_tfidf, y_train)" ] }, { "cell_type": "markdown", "id": "5c747eab", "metadata": {}, "source": [ "Change the Alpha to 0.1 and max_features to 4000 for better performance" ] }, { "cell_type": "code", "execution_count": 29, "id": "71d0742f", "metadata": {}, "outputs": [], "source": [ "import time" ] }, { "cell_type": "code", "execution_count": 30, "id": "b22c1073", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of Naive Bayes model: 0.7923875432525952\n", " precision recall f1-score support\n", "\n", " Accident 0.74 0.84 0.79 129\n", "Administrative Issue 0.89 0.87 0.88 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.67 0.22 0.33 18\n", " Others 0.45 0.35 0.40 79\n", " Terrorism 0.54 0.40 0.46 52\n", " Weather 0.77 0.93 0.85 92\n", " Worker Strike 0.65 0.75 0.69 120\n", "\n", " accuracy 0.79 1156\n", " macro avg 0.71 0.58 0.60 1156\n", " weighted avg 0.79 0.79 0.79 1156\n", "\n", "Total Runtime: 0.11717486381530762\n" ] } ], "source": [ "X = result_df[\"Details_cleaned\"]\n", "y = result_df[\"Summarized_label\"]\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "start_time = time.time()\n", "tfidf_vectorizer = 
TfidfVectorizer(max_features=4000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)\n", "\n", "naive_bayes = MultinomialNB(alpha=0.1)\n", "naive_bayes.fit(X_train_tfidf, y_train)\n", "\n", "predictions = naive_bayes.predict(X_test_tfidf)\n", "\n", "end_time = time.time()\n", "total_runtime = end_time - start_time\n", "\n", "accuracy = accuracy_score(y_test, predictions)\n", "print(\"Accuracy of Naive Bayes model:\", accuracy)\n", "print(classification_report(y_test, predictions))\n", "\n", "print(\"Total Runtime:\", total_runtime)" ] }, { "cell_type": "markdown", "id": "aa011ad5", "metadata": {}, "source": [ "## Logistic Regression model" ] }, { "cell_type": "code", "execution_count": 31, "id": "6e735f18", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 32, "id": "e266616c", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 33, "id": "b1314e98", "metadata": {}, "outputs": [], "source": [ "tfidf_vectorizer = TfidfVectorizer(max_features=1000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 34, "id": "87905c28", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LogisticRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LogisticRegression()" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = LogisticRegression()\n", "model.fit(X_train_tfidf, y_train)" ] }, { "cell_type": "code", "execution_count": 35, "id": "c4bf008a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of Logistic Regression Model: 0.7975778546712803\n", " precision recall f1-score support\n", "\n", " Accident 0.79 0.81 0.80 129\n", "Administrative Issue 0.83 0.93 0.88 662\n", " Cyber Attack 0.00 0.00 0.00 4\n", " Human Error 0.00 0.00 0.00 18\n", " Others 0.64 0.34 0.45 79\n", " Terrorism 0.46 0.21 0.29 52\n", " Weather 0.83 0.87 0.85 92\n", " Worker Strike 0.69 0.71 0.70 120\n", "\n", " accuracy 0.80 1156\n", " macro avg 0.53 0.48 0.50 1156\n", " weighted avg 0.77 0.80 0.78 1156\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. 
Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] } ], "source": [ "y_pred = model.predict(X_test_tfidf)\n", "\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of Logistic Regression Model:\", accuracy)\n", "print(classification_report(y_test, y_pred))" ] }, { "cell_type": "code", "execution_count": 36, "id": "69b1b25a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Best Parameters: {'model__C': 10.0, 'tfidf__max_features': 2000}\n", "Accuracy of Tuned Logistic Regression Model: 0.8200692041522492\n", " precision recall f1-score support\n", "\n", " Accident 0.81 0.86 0.83 129\n", "Administrative Issue 0.86 0.91 0.88 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.60 0.17 0.26 18\n", " Others 0.61 0.43 0.50 79\n", " Terrorism 0.61 0.44 0.51 52\n", " Weather 0.87 0.90 0.89 92\n", " Worker Strike 0.73 0.75 0.74 120\n", "\n", " accuracy 0.82 1156\n", " macro avg 0.76 0.59 0.63 1156\n", " weighted avg 0.81 0.82 0.81 1156\n", "\n" ] } ], "source": [ "from sklearn.pipeline import Pipeline\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "param_grid = {\n", " \"tfidf__max_features\": [500, 1000, 2000, 3000, 4000],\n", " \"model__C\": [0.1, 1.0, 10.0],\n", "}\n", "\n", "pipeline = Pipeline([(\"tfidf\", TfidfVectorizer()), (\"model\", 
LogisticRegression())])\n", "\n", "grid_search = GridSearchCV(pipeline, param_grid, cv=5, scoring=\"accuracy\")\n", "\n", "grid_search.fit(X_train, y_train)\n", "\n", "best_params = grid_search.best_params_\n", "print(\"Best Parameters:\", best_params)\n", "\n", "best_model = grid_search.best_estimator_\n", "best_model.fit(X_train, y_train)\n", "\n", "y_pred = best_model.predict(X_test)\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of Tuned Logistic Regression Model:\", accuracy)\n", "print(classification_report(y_test, y_pred))" ] }, { "cell_type": "markdown", "id": "c74436a2", "metadata": {}, "source": [ "The best parameters are 'model__C': 10.0, 'tfidf__max_features': 2000" ] }, { "cell_type": "code", "execution_count": 37, "id": "7d7e7e31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of Logistic Regression Model: 0.8200692041522492\n", " precision recall f1-score support\n", "\n", " Accident 0.81 0.86 0.83 129\n", "Administrative Issue 0.86 0.91 0.88 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.60 0.17 0.26 18\n", " Others 0.61 0.43 0.50 79\n", " Terrorism 0.61 0.44 0.51 52\n", " Weather 0.87 0.90 0.89 92\n", " Worker Strike 0.73 0.75 0.74 120\n", "\n", " accuracy 0.82 1156\n", " macro avg 0.76 0.59 0.63 1156\n", " weighted avg 0.81 0.82 0.81 1156\n", "\n", "Total Runtime: 0.3288562297821045\n" ] } ], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "start_time = time.time()\n", "tfidf_vectorizer = TfidfVectorizer(max_features=2000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)\n", "\n", "model = LogisticRegression(C=10.0)\n", "model.fit(X_train_tfidf, y_train)\n", "\n", "y_pred = model.predict(X_test_tfidf)\n", "\n", "end_time = time.time()\n", "total_runtime = end_time - start_time\n", "\n", "accuracy = accuracy_score(y_test, 
y_pred)\n", "print(\"Accuracy of Logistic Regression Model:\", accuracy)\n", "print(classification_report(y_test, y_pred))\n", "\n", "print(\"Total Runtime:\", total_runtime)" ] }, { "cell_type": "markdown", "id": "482d0503", "metadata": {}, "source": [ "## Support Vector Machine (SVM) model" ] }, { "cell_type": "code", "execution_count": 38, "id": "9a2b2117", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.svm import SVC\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 39, "id": "f8e29f39", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 40, "id": "246cca7a", "metadata": {}, "outputs": [], "source": [ "tfidf_vectorizer = TfidfVectorizer(max_features=1000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 41, "id": "393b87b3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
SVC(kernel='linear')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "SVC(kernel='linear')" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "svm_model = SVC(kernel=\"linear\")\n", "svm_model.fit(X_train_tfidf, y_train)" ] }, { "cell_type": "code", "execution_count": 42, "id": "fc25cdcf", "metadata": {}, "outputs": [], "source": [ "y_pred = svm_model.predict(X_test_tfidf)" ] }, { "cell_type": "code", "execution_count": 43, "id": "2960279a", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of SVM model: 0.8183391003460208\n", " precision recall f1-score support\n", "\n", " Accident 0.78 0.82 0.80 129\n", "Administrative Issue 0.87 0.92 0.89 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.67 0.11 0.19 18\n", " Others 0.62 0.42 0.50 79\n", " Terrorism 0.55 0.31 0.40 52\n", " Weather 0.82 0.90 0.86 92\n", " Worker Strike 0.72 0.80 0.76 120\n", "\n", " accuracy 0.82 1156\n", " macro avg 0.75 0.57 0.60 1156\n", " weighted avg 0.81 0.82 0.80 1156\n", "\n" ] } ], "source": [ "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of SVM model:\", accuracy)\n", "print(classification_report(y_test, y_pred))" ] }, { "cell_type": "code", "execution_count": 44, "id": "4e9fee70", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best C: 10\n" ] } ], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "param_grid = {\"C\": [0.1, 1, 10]}\n", "svm = SVC()\n", "grid_search = GridSearchCV(svm, param_grid, cv=5, scoring=\"accuracy\")\n", "grid_search.fit(X_train_tfidf, y_train)\n", "best_c = grid_search.best_params_[\"C\"]\n", "print(\"Best C:\", best_c)" ] }, { "cell_type": "code", "execution_count": 45, "id": "65fd932b-63e8-4041-b7aa-0fae14e48efe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of SVM model: 0.782871972318339\n", " precision recall f1-score support\n", "\n", " Accident 0.72 0.84 0.77 
129\n", "Administrative Issue 0.86 0.86 0.86 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.62 0.28 0.38 18\n", " Others 0.51 0.46 0.48 79\n", " Terrorism 0.49 0.38 0.43 52\n", " Weather 0.81 0.87 0.84 92\n", " Worker Strike 0.69 0.69 0.69 120\n", "\n", " accuracy 0.78 1156\n", " macro avg 0.71 0.58 0.61 1156\n", " weighted avg 0.78 0.78 0.78 1156\n", "\n" ] } ], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "tfidf_vectorizer = TfidfVectorizer(max_features=1000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)\n", "\n", "svm_model = SVC(kernel=\"linear\", C=10)\n", "svm_model.fit(X_train_tfidf, y_train)\n", "\n", "y_pred = svm_model.predict(X_test_tfidf)\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of SVM model:\", accuracy)\n", "print(classification_report(y_test, y_pred))" ] }, { "cell_type": "markdown", "id": "a2843fa9", "metadata": {}, "source": [ "But when C is set to 10, the accuracy drops, it may be due to overfitting. 
We will still use the default value C=1.0. Note that the grid search above tuned C on `SVC()` with its default RBF kernel, whereas the models trained here use `kernel='linear'`, so the tuned value does not necessarily transfer to the linear model." ] }, { "cell_type": "code", "execution_count": 46, "id": "afffe960", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of SVM model: 0.8217993079584776\n", " precision recall f1-score support\n", "\n", " Accident 0.82 0.86 0.84 129\n", "Administrative Issue 0.86 0.93 0.89 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.00 0.00 0.00 18\n", " Others 0.64 0.41 0.50 79\n", " Terrorism 0.61 0.33 0.42 52\n", " Weather 0.83 0.90 0.86 92\n", " Worker Strike 0.71 0.77 0.74 120\n", "\n", " accuracy 0.82 1156\n", " macro avg 0.68 0.55 0.58 1156\n", " weighted avg 0.80 0.82 0.81 1156\n", "\n", "Total Runtime: 3.1480040550231934\n" ] } ], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "start_time = time.time()\n", "tfidf_vectorizer = TfidfVectorizer(max_features=2000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)\n", "\n", "svm_model = SVC(kernel=\"linear\")\n", "svm_model.fit(X_train_tfidf, y_train)\n", "\n", "y_pred = svm_model.predict(X_test_tfidf)\n", "\n", "end_time = time.time()\n", "total_runtime = end_time - start_time\n", "\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of SVM model:\", accuracy)\n", "print(classification_report(y_test, y_pred))\n", "\n", "print(\"Total Runtime:\", total_runtime)" ] }, { "cell_type": "markdown", "id": "deac9dd7", "metadata": {}, "source": [ "## Random Forest Model" ] }, { "cell_type": "code", "execution_count": 47, "id": "fba3d3c4", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 48, "id": "390399c2", 
"metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 49, "id": "74d99fe7", "metadata": {}, "outputs": [], "source": [ "tfidf_vectorizer = TfidfVectorizer(max_features=1000)\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 50, "id": "f37ceeae", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestClassifier(random_state=42)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestClassifier(random_state=42)" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rf_model = RandomForestClassifier(n_estimators=100, random_state=42)\n", "rf_model.fit(X_train_tfidf, y_train)" ] }, { "cell_type": "code", "execution_count": 51, "id": "51cbc1c4", "metadata": {}, "outputs": [], "source": [ "y_pred = rf_model.predict(X_test_tfidf)" ] }, { "cell_type": "code", "execution_count": 52, "id": "688925b0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of Random Forest Model: 0.801038062283737\n", " precision recall f1-score support\n", "\n", " Accident 0.77 0.80 0.79 129\n", "Administrative Issue 0.84 0.92 0.88 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.50 0.06 0.10 18\n", " Others 0.72 0.39 0.51 79\n", " Terrorism 0.67 0.19 0.30 52\n", " Weather 0.79 0.86 0.82 92\n", " Worker Strike 0.66 0.77 0.71 120\n", "\n", " accuracy 0.80 1156\n", " macro avg 0.74 0.53 0.56 1156\n", " weighted avg 0.79 0.80 0.78 1156\n", "\n" ] } ], "source": [ "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of Random Forest Model:\", accuracy)\n", "print(classification_report(y_test, y_pred))" ] }, { "cell_type": "markdown", "id": "4b919b55", "metadata": {}, "source": [ "Fine-tuning by adjusting the hyperparameters. After testing different hyperparameter values, below are the best parameters found for this model." 
] }, { "cell_type": "code", "execution_count": 53, "id": "6b4868ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of Random Forest Model: 0.8070934256055363\n", " precision recall f1-score support\n", "\n", " Accident 0.80 0.79 0.80 129\n", "Administrative Issue 0.83 0.94 0.88 662\n", " Cyber Attack 1.00 0.25 0.40 4\n", " Human Error 0.50 0.06 0.10 18\n", " Others 0.74 0.41 0.52 79\n", " Terrorism 0.86 0.12 0.20 52\n", " Weather 0.82 0.85 0.83 92\n", " Worker Strike 0.67 0.78 0.72 120\n", "\n", " accuracy 0.81 1156\n", " macro avg 0.78 0.52 0.56 1156\n", " weighted avg 0.80 0.81 0.78 1156\n", "\n", "Total Runtime: 2.4808011054992676\n" ] } ], "source": [ "# Imported here because this cell is the first one in this section to use the\n", "# timer; without it a fresh-kernel run fails with NameError on time.time().\n", "import time\n", "\n", "# Tuned Random Forest: 2000-term TF-IDF vocabulary, 300 trees and\n", "# min_samples_split=5, evaluated on an 80/20 split with random_state=42.\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "# The timer covers vectorization, training and prediction.\n", "start_time = time.time()\n", "tfidf_vectorizer = TfidfVectorizer(max_features=2000)\n", "# Fit on the training documents only; the test documents are just transformed.\n", "X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n", "X_test_tfidf = tfidf_vectorizer.transform(X_test)\n", "\n", "rf_model = RandomForestClassifier(\n", "    n_estimators=300, min_samples_split=5, random_state=42\n", ")\n", "rf_model.fit(X_train_tfidf, y_train)\n", "\n", "y_pred = rf_model.predict(X_test_tfidf)\n", "end_time = time.time()\n", "total_runtime = end_time - start_time\n", "\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy of Random Forest Model:\", accuracy)\n", "print(classification_report(y_test, y_pred))\n", "\n", "print(\"Total Runtime:\", total_runtime)" ] }, { "cell_type": "markdown", "id": "7df52b09", "metadata": {}, "source": [ "### KNN" ] }, { "cell_type": "code", "execution_count": 54, "id": "b8822f38", "metadata": {}, "outputs": [], "source": [ "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 55, 
"id": "368a2dd1", "metadata": {}, "outputs": [], "source": [ "vectorizer = TfidfVectorizer(max_features=2000)\n", "X = vectorizer.fit_transform(X)" ] }, { "cell_type": "code", "execution_count": 56, "id": "ae8bae0b", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 57, "id": "3ef3809f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.7889273356401384\n" ] } ], "source": [ "# Step 4: Apply KNN Algorithm\n", "k = 5 # Number of neighbors\n", "knn_model = KNeighborsClassifier(n_neighbors=k)\n", "knn_model.fit(X_train, y_train)\n", "\n", "# Step 5: Make Predictions and Evaluate Performance\n", "y_pred = knn_model.predict(X_test)\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy:\", accuracy)" ] }, { "cell_type": "markdown", "id": "cca13522-7877-4f1f-9b15-d8fb6fd55d12", "metadata": {}, "source": [ "Plot the model's performance against values of k to find the optimal k" ] }, { "cell_type": "code", "execution_count": 58, "id": "67102a37-2286-442f-b270-d3c00614dd9c", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.datasets import load_iris\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "k_values = range(1, 21)\n", "\n", "train_scores = []\n", "test_scores = []\n", "\n", "# Iterate over each k value\n", "for k in k_values:\n", " # Train KNN classifier\n", " knn = KNeighborsClassifier(n_neighbors=k)\n", " knn.fit(X_train, y_train)\n", "\n", " # Calculate training and testing accuracy\n", " train_score = knn.score(X_train, y_train)\n", " test_score = knn.score(X_test, y_test)\n", "\n", " train_scores.append(train_score)\n", " test_scores.append(test_score)\n", "\n", "# Plot the performance scores\n", "plt.figure(figsize=(10, 6))\n", "plt.plot(k_values, train_scores, label=\"Train Accuracy\", marker=\"o\")\n", "plt.plot(k_values, test_scores, label=\"Test Accuracy\", marker=\"o\")\n", "plt.xlabel(\"Number of Neighbors (k)\")\n", "plt.ylabel(\"Accuracy\")\n", "plt.title(\"KNN Classifier Performance\")\n", "plt.xticks(np.arange(1, 21, step=1))\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "38fbfc18-9f0f-405f-952e-74725d3fb6ed", "metadata": {}, "source": [ "k=5 is an optimal value" ] }, { "cell_type": "markdown", "id": "f2e34a7b-6c6b-4308-874e-f74899336c61", "metadata": {}, "source": [ "Find other optimal hyperparameters by using grid search" ] }, { "cell_type": "code", "execution_count": 59, "id": "5725954c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:982: UserWarning: Scoring failed. 
The score on this train-test partition for these parameters will be set to nan. Details: \n", "Traceback (most recent call last):\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 971, in _score\n", " scores = scorer(estimator, X_test, y_test, **score_params)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 279, in __call__\n", " return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 371, in _score\n", " y_pred = method_caller(\n", " ^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 89, in _cached_call\n", " result, _ = _get_response_values(\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/utils/_response.py\", line 211, in _get_response_values\n", " y_pred = prediction_method(X)\n", " ^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 259, in predict\n", " probabilities = self.predict_proba(X)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 343, in predict_proba\n", " probabilities = ArgKminClassMode.compute(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 590, in compute\n", " unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),\n", " 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "ValueError: invalid literal for int() with base 10: 'Accident'\n", "\n", " warnings.warn(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: \n", "Traceback (most recent call last):\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 971, in _score\n", " scores = scorer(estimator, X_test, y_test, **score_params)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 279, in __call__\n", " return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 371, in _score\n", " y_pred = method_caller(\n", " ^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 89, in _cached_call\n", " result, _ = _get_response_values(\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/utils/_response.py\", line 211, in _get_response_values\n", " y_pred = prediction_method(X)\n", " ^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 259, in predict\n", " probabilities = self.predict_proba(X)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 343, in predict_proba\n", " probabilities = ArgKminClassMode.compute(\n", " 
^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 590, in compute\n", " unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "ValueError: invalid literal for int() with base 10: 'Accident'\n", "\n", " warnings.warn(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: \n", "Traceback (most recent call last):\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 971, in _score\n", " scores = scorer(estimator, X_test, y_test, **score_params)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 279, in __call__\n", " return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 371, in _score\n", " y_pred = method_caller(\n", " ^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 89, in _cached_call\n", " result, _ = _get_response_values(\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/utils/_response.py\", line 211, in _get_response_values\n", " y_pred = prediction_method(X)\n", " ^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 259, in predict\n", " 
probabilities = self.predict_proba(X)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 343, in predict_proba\n", " probabilities = ArgKminClassMode.compute(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 590, in compute\n", " unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "ValueError: invalid literal for int() with base 10: 'Accident'\n", "\n", " warnings.warn(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: \n", "Traceback (most recent call last):\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 971, in _score\n", " scores = scorer(estimator, X_test, y_test, **score_params)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 279, in __call__\n", " return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 371, in _score\n", " y_pred = method_caller(\n", " ^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 89, in _cached_call\n", " result, _ = _get_response_values(\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File 
\"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/utils/_response.py\", line 211, in _get_response_values\n", " y_pred = prediction_method(X)\n", " ^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 259, in predict\n", " probabilities = self.predict_proba(X)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 343, in predict_proba\n", " probabilities = ArgKminClassMode.compute(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 590, in compute\n", " unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "ValueError: invalid literal for int() with base 10: 'Accident'\n", "\n", " warnings.warn(\n", "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. 
Details: \n", "Traceback (most recent call last):\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 971, in _score\n", " scores = scorer(estimator, X_test, y_test, **score_params)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 279, in __call__\n", " return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 371, in _score\n", " y_pred = method_caller(\n", " ^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_scorer.py\", line 89, in _cached_call\n", " result, _ = _get_response_values(\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/utils/_response.py\", line 211, in _get_response_values\n", " y_pred = prediction_method(X)\n", " ^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 259, in predict\n", " probabilities = self.predict_proba(X)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/neighbors/_classification.py\", line 343, in predict_proba\n", " probabilities = ArgKminClassMode.compute(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 590, in compute\n", " unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp),\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "ValueError: invalid literal for int() with base 10: 
'Accident'\n", "\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Best Parameters: {'p': 2, 'weights': 'distance'}\n", "Test Accuracy: 0.7993079584775087\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/inflaton/anaconda3/envs/maritime/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1052: UserWarning: One or more of the test scores are non-finite: [ nan 0.58348075 0.77465707 0.78136118]\n", " warnings.warn(\n" ] } ], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "\n", "# Encode the string class labels as integers for the search. With the raw\n", "# string labels, scoring one of the four candidates raised\n", "# \"invalid literal for int() with base 10: 'Accident'\" in every fold, so its\n", "# score was recorded as NaN and it was never really compared with the others.\n", "# The encoder is fitted on the full label column so every class is known.\n", "label_encoder = LabelEncoder().fit(y)\n", "y_train_encoded = label_encoder.transform(y_train)\n", "y_test_encoded = label_encoder.transform(y_test)\n", "\n", "knn = KNeighborsClassifier()\n", "\n", "# Neighbor weighting scheme and Minkowski power (1 = Manhattan, 2 = Euclidean)\n", "param_grid = {\"weights\": [\"uniform\", \"distance\"], \"p\": [1, 2]}\n", "\n", "grid_search = GridSearchCV(\n", "    estimator=knn, param_grid=param_grid, cv=5, scoring=\"accuracy\"\n", ")\n", "\n", "grid_search.fit(X_train, y_train_encoded)\n", "\n", "best_params = grid_search.best_params_\n", "\n", "# best_estimator_ was trained on the encoded labels, so score it against them\n", "best_model = grid_search.best_estimator_\n", "\n", "test_accuracy = best_model.score(X_test, y_test_encoded)\n", "print(\"Best Parameters:\", best_params)\n", "print(\"Test Accuracy:\", test_accuracy)" ] }, { "cell_type": "markdown", "id": "6b7449ea-16e7-4660-89c6-24fe635f2880", "metadata": {}, "source": [ "Lastly, run the model with optimal hyperparameters" ] }, { "cell_type": "code", "execution_count": 60, "id": "50fb3195-fe1c-499a-9157-0be8dc7be3e1", "metadata": {}, "outputs": [], "source": [ "import time" ] }, { "cell_type": "code", "execution_count": 61, "id": "dbd33ce9-ebc7-42d8-a190-013c1d889286", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.7993079584775087\n", "Total Runtime: 0.10337376594543457\n" ] } ], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "start_time = time.time()\n", "\n", "k = 5\n", "knn_model = KNeighborsClassifier(n_neighbors=k, weights=\"distance\")\n", "knn_model.fit(X_train, y_train)\n", "\n", "y_pred = knn_model.predict(X_test)\n", "\n", "end_time = time.time()\n", "total_runtime = 
end_time - start_time\n", "\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(\"Accuracy:\", accuracy)\n", "print(\"Total Runtime:\", total_runtime)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 }