{
"cells": [
{
"metadata": {
"id": "_Ly997V5FRvJ"
},
"cell_type": "markdown",
"source": [
"# Implementation of text classification with BERT\n",
"\n",
"\n",
"Still Working on it.\n",
"\n",
"This notebook is based in this TensorFlow tutorial: [Classify text with BERT](https://www.tensorflow.org/tutorials/text/classify_text_with_bert)\n",
"\n",
"BERT [(article link)](https://arxiv.org/abs/1810.04805) and other Transformer encoder architectures have been wildly successful on a variety of tasks in NLP (natural language processing). They compute vector-space representations of natural language that are suitable for use in deep learning models.\n",
"\n",
"![](http://www.d2l.ai/_images/nlp-map-pretrain.svg)\n",
"\n",
"Source: http://www.d2l.ai/chapter_natural-language-processing-pretraining/index.html\n",
"\n",
"BERT models are usually pre-trained on a large corpus of text, then fine-tuned for specific tasks.\n",
"\n",
"In this notebook, I am going to use a pretreined BERT to compute vector-space representations of a hate speech dataset to feed two different downsteam Archtectures (CNN and MLP).\n",
"\n",
"Sentiment Analysis\n",
"\n",
"This notebook trains a sentiment analysis model to classify the [Hate Speech and Offensive Language Dataset]( https://www.kaggle.com/mrmorj/hate-speech-and-offensive-language-dataset) tweets in three classes:\n",
" \n",
"* 0 - hate speech \n",
"* 1 - offensive language \n",
"* 2 - neither as positive or negative"
]
},
{
"metadata": {
"id": "0E1ATVOAFRvL"
},
"cell_type": "markdown",
"source": [
"## Installing dependencies and importing packages"
]
},
{
"metadata": {
"trusted": true,
"id": "m_AYtaEKFRvN"
},
"cell_type": "code",
"source": [
"# A dependency of the preprocessing for BERT inputs\n",
"!pip install -q tensorflow-text > /dev/null\n",
"!pip install -q tf-models-official > /dev/null\n",
"!pip install -q transformers > /dev/null"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# this jupyter was running in collab\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zEiku7sUB2Bo",
"outputId": "94116144-847a-4d33-aa61-29ab3597b83e"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"trusted": true,
"id": "mehF1PbZFRvM"
},
"cell_type": "code",
"source": [
"# This Python 3 environment comes with many helpful analytics libraries installed\n",
"# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
"# For example, here's several helpful packages to load\n",
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Input data files are available in the read-only \"../input/\" directory\n",
"# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
"import os\n",
"for dirname, _, filenames in os.walk('/kaggle/input'):\n",
" for filename in filenames:\n",
" print(os.path.join(dirname, filename))\n",
"\n",
"# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
"# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "pP-P1A0nFRvO",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "460653d2-1650-47c5-c756-3a7f5d4c8f05"
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import GroupKFold\n",
"import matplotlib.pyplot as plt\n",
"from tqdm.notebook import tqdm\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import tensorflow_text as text\n",
"import tensorflow.keras.backend as K\n",
"from scipy.stats import spearmanr\n",
"from math import floor, ceil\n",
"from transformers import *\n",
"\n",
"np.set_printoptions(suppress=True)\n",
"print(tf.__version__)\n",
"\n",
"import shutil\n",
"\n",
"from official.nlp import optimization # to create AdamW optmizer\n",
"\n",
"tf.get_logger().setLevel('ERROR')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.11.0\n"
]
}
]
},
{
"metadata": {
"id": "cfU6jpbDFRvO"
},
"cell_type": "markdown",
"source": [
"## Reading and preparing the dataset"
]
},
{
"cell_type": "code",
"source": [
"# replace username and key\n",
"!KAGGLE_USERNAME=xxx KAGGLE_KEY=xxx kaggle datasets download -d mrmorj/hate-speech-and-offensive-language-dataset"
],
"metadata": {
"id": "6pnu6cmKMBCi",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "48abb781-4d94-4e40-ab0d-9411d1621189"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"hate-speech-and-offensive-language-dataset.zip: Skipping, found more recently modified local copy (use --force to force download)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!mkdir -p /input/hate-speech-and-offensive-language-dataset/\n",
"!unzip -o /content/hate-speech-and-offensive-language-dataset.zip -d /input/hate-speech-and-offensive-language-dataset/"
],
"metadata": {
"id": "COTWNFw3RYtq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "8f3174ce-c298-402c-ce35-040578b75fcc"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Archive: /content/hate-speech-and-offensive-language-dataset.zip\n",
" inflating: /input/hate-speech-and-offensive-language-dataset/labeled_data.csv \n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "rkrreVusFRvO"
},
"cell_type": "code",
"source": [
"PATH = '../input/hate-speech-and-offensive-language-dataset/'\n",
"df = pd.read_csv(PATH+'labeled_data.csv')"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "TY4lZ3O5FRvO",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1661a1f1-7490-4d15-f6e2-67d412a0fff3"
},
"cell_type": "code",
"source": [
"nRowsRead = None # specify 'None' if want to read whole file\n",
"# labeled_data.csv may have more rows in reality, but we are only loading/previewing the first 1000 rows\n",
"df0 = pd.read_csv('../input/hate-speech-and-offensive-language-dataset/labeled_data.csv', delimiter=',', nrows = nRowsRead)\n",
"df0.dataframeName = 'labeled_data.csv'\n",
"nRow, nCol = df0.shape\n",
"print('There are {} rows and {} columns'.format(nRow, nCol))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"There are 24783 rows and 7 columns\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "fMTsmoD9FRvO",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "957ed7d5-f5cc-45e8-d68c-33e74ff58af8"
},
"cell_type": "code",
"source": [
"df0.head(5)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Unnamed: 0 count hate_speech offensive_language neither class \\\n",
"0 0 3 0 0 3 2 \n",
"1 1 3 0 3 0 1 \n",
"2 2 3 0 3 0 1 \n",
"3 3 3 0 2 1 1 \n",
"4 4 6 0 6 0 1 \n",
"\n",
" tweet \n",
"0 !!! RT @mayasolovely: As a woman you shouldn't... \n",
"1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... \n",
"2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... \n",
"3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... \n",
"4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... "
]
},
"metadata": {},
"execution_count": 44
}
]
},
{
"metadata": {
"trusted": true,
"id": "DYA58oAyFRvP",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"outputId": "ef4e855b-051e-4a33-cb4e-c4989db11b2c"
},
"cell_type": "code",
"source": [
"#Doing some adjustments\n",
"\n",
"c=df0['class']\n",
"df0.rename(columns={'tweet' : 'text',\n",
" 'class' : 'category'}, \n",
" inplace=True)\n",
"a=df0['text']\n",
"b=df0['category'].map({0: 'hate_speech', 1: 'offensive_language',2: 'neither'})\n",
"\n",
"df= pd.concat([a,b,c], axis=1)\n",
"df.rename(columns={'class' : 'label'}, \n",
" inplace=True)\n",
"df"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text category \\\n",
"0 !!! RT @mayasolovely: As a woman you shouldn't... neither \n",
"1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... offensive_language \n",
"2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... offensive_language \n",
"3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... offensive_language \n",
"4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... offensive_language \n",
"... ... ... \n",
"24778 you's a muthaf***in lie “@LifeAsKing: @2... offensive_language \n",
"24779 you've gone and broke the wrong heart baby, an... neither \n",
"24780 young buck wanna eat!!.. dat nigguh like I ain... offensive_language \n",
"24781 youu got wild bitches tellin you lies offensive_language \n",
"24782 ~~Ruffled | Ntac Eileen Dahlia - Beautiful col... neither \n",
"\n",
" label \n",
"0 2 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 \n",
"... ... \n",
"24778 1 \n",
"24779 2 \n",
"24780 1 \n",
"24781 1 \n",
"24782 2 \n",
"\n",
"[24783 rows x 3 columns]"
]
},
"metadata": {},
"execution_count": 45
}
]
},
{
"metadata": {
"trusted": true,
"id": "6JhjkvfrFRvP",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
},
"outputId": "33953f56-8ff2-4eb6-d199-eb4fb8732ad2"
},
"cell_type": "code",
"source": [
"# Grouping data by label\n",
"df.groupby('label').count()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text category\n",
"label \n",
"0 1430 1430\n",
"1 19190 19190\n",
"2 4163 4163"
]
},
"metadata": {},
"execution_count": 46
}
]
},
{
"metadata": {
"id": "PS4dSjrvFRvP"
},
"cell_type": "markdown",
"source": [
"This is an unbalanced dataset. "
]
},
{
"metadata": {
"trusted": true,
"id": "b9YYGHIZFRvP",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "08ec3b2e-2e67-4f02-89da-2ba0b76f85c2"
},
"cell_type": "code",
"source": [
"hate, ofensive, neither = np.bincount(df['label'])\n",
"total = hate + ofensive + neither\n",
"print('Examples:\\n Total: {}\\n hate: {} ({:.2f}% of total)\\n'.format(\n",
" total, hate, 100 * hate / total))\n",
"print('Examples:\\n Total: {}\\n Ofensive: {} ({:.2f}% of total)\\n'.format(\n",
" total, ofensive, 100 * ofensive / total))\n",
"print('Examples:\\n Total: {}\\n Neither: {} ({:.2f}% of total)\\n'.format(\n",
" total, neither, 100 * neither / total))\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Examples:\n",
" Total: 24783\n",
" hate: 1430 (5.77% of total)\n",
"\n",
"Examples:\n",
" Total: 24783\n",
" Ofensive: 19190 (77.43% of total)\n",
"\n",
"Examples:\n",
" Total: 24783\n",
" Neither: 4163 (16.80% of total)\n",
"\n"
]
}
]
},
{
"metadata": {
"id": "DviRMu0yFRvQ"
},
"cell_type": "markdown",
"source": [
"### Splitting the data between train, validation and test sets:"
]
},
{
"metadata": {
"trusted": true,
"id": "HGn5vsjHFRvQ"
},
"cell_type": "code",
"source": [
"X_train_, X_test, y_train_, y_test = train_test_split(\n",
" df.index.values,\n",
" df.label.values,\n",
" test_size=0.10,\n",
" random_state=42,\n",
" stratify=df.label.values, \n",
")"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "nLFmmeBHFRvQ"
},
"cell_type": "code",
"source": [
"X_train, X_val, y_train, y_val = train_test_split(\n",
" df.loc[X_train_].index.values,\n",
" df.loc[X_train_].label.values,\n",
" test_size=0.10,\n",
" random_state=42,\n",
" stratify=df.loc[X_train_].label.values, \n",
")"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "MoWyqeZIFRvQ"
},
"cell_type": "code",
"source": [
"df['data_type'] = ['not_set']*df.shape[0]\n",
"df.loc[X_train, 'data_type'] = 'train'\n",
"df.loc[X_val, 'data_type'] = 'val'\n",
"df.loc[X_test, 'data_type'] = 'test'"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "iahBQXCtFRvQ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"outputId": "441c5d31-e8d3-4fab-da3b-680c73b00ed8"
},
"cell_type": "code",
"source": [
"df.groupby(['category', 'label', 'data_type']).count()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text\n",
"category label data_type \n",
"hate_speech 0 test 143\n",
" train 1158\n",
" val 129\n",
"neither 2 test 416\n",
" train 3372\n",
" val 375\n",
"offensive_language 1 test 1920\n",
" train 15543\n",
" val 1727"
]
},
"metadata": {},
"execution_count": 51
}
]
},
{
"metadata": {
"trusted": true,
"id": "0YgYtV1bFRvQ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"outputId": "d1a72299-0ef0-4847-f5f5-499c65ce0d2a"
},
"cell_type": "code",
"source": [
"df"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text category \\\n",
"0 !!! RT @mayasolovely: As a woman you shouldn't... neither \n",
"1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... offensive_language \n",
"2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... offensive_language \n",
"3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... offensive_language \n",
"4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... offensive_language \n",
"... ... ... \n",
"24778 you's a muthaf***in lie “@LifeAsKing: @2... offensive_language \n",
"24779 you've gone and broke the wrong heart baby, an... neither \n",
"24780 young buck wanna eat!!.. dat nigguh like I ain... offensive_language \n",
"24781 youu got wild bitches tellin you lies offensive_language \n",
"24782 ~~Ruffled | Ntac Eileen Dahlia - Beautiful col... neither \n",
"\n",
" label data_type \n",
"0 2 test \n",
"1 1 train \n",
"2 1 train \n",
"3 1 train \n",
"4 1 train \n",
"... ... ... \n",
"24778 1 train \n",
"24779 2 train \n",
"24780 1 train \n",
"24781 1 train \n",
"24782 2 train \n",
"\n",
"[24783 rows x 4 columns]"
]
},
"metadata": {},
"execution_count": 52
}
]
},
{
"metadata": {
"trusted": true,
"id": "sBl6CDIrFRvR",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "f6bd6032-5f67-464a-8328-017b579ba60a"
},
"cell_type": "code",
"source": [
"df_train = df.loc[df[\"data_type\"]==\"train\"]\n",
"df_train.head(5)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text category \\\n",
"1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... offensive_language \n",
"2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... offensive_language \n",
"3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... offensive_language \n",
"4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... offensive_language \n",
"6 !!!!!!\"@__BrighterDays: I can not just sit up ... offensive_language \n",
"\n",
" label data_type \n",
"1 1 train \n",
"2 1 train \n",
"3 1 train \n",
"4 1 train \n",
"6 1 train "
]
},
"metadata": {},
"execution_count": 53
}
]
},
{
"metadata": {
"trusted": true,
"id": "wbv2v7ZNFRvR",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "5f3177b5-a872-405c-e562-8bc07ddb58df"
},
"cell_type": "code",
"source": [
"df_val = df.loc[df[\"data_type\"]==\"val\"]\n",
"df_val.head(5)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text category \\\n",
"5 !!!!!!!!!!!!!!!!!!\"@T_Madison_x: The shit just... offensive_language \n",
"27 \" i met that pussy on Ocean Dr . i gave that p... offensive_language \n",
"31 \" i'd say im back to the old me but my old bit... offensive_language \n",
"44 \" post a picture of that pussy get 200 likes \" offensive_language \n",
"46 \" quick piece of pussy call it a drive by \" offensive_language \n",
"\n",
" label data_type \n",
"5 1 val \n",
"27 1 val \n",
"31 1 val \n",
"44 1 val \n",
"46 1 val "
]
},
"metadata": {},
"execution_count": 54
}
]
},
{
"metadata": {
"trusted": true,
"id": "BK6-RPk6FRvR",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "0b2dcb3e-a65a-45bf-d9f9-46915e2a2409"
},
"cell_type": "code",
"source": [
"df_test = df.loc[df[\"data_type\"]==\"test\"]\n",
"df_test.head(5)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text category \\\n",
"0 !!! RT @mayasolovely: As a woman you shouldn't... neither \n",
"12 \" So hoes that smoke are losers ? \" yea ... go... offensive_language \n",
"14 \" bitch get up off me \" offensive_language \n",
"17 \" bitch who do you love \" offensive_language \n",
"25 \" her pussy lips like Heaven doors \" 😌 offensive_language \n",
"\n",
" label data_type \n",
"0 2 test \n",
"12 1 test \n",
"14 1 test \n",
"17 1 test \n",
"25 1 test "
]
},
"metadata": {},
"execution_count": 55
}
]
},
{
"metadata": {
"trusted": true,
"id": "zcd2RCCnFRvR",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e0c2384c-a60a-4f81-e347-2592993ac5cc"
},
"cell_type": "code",
"source": [
"df.dtypes"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"text object\n",
"category object\n",
"label int64\n",
"data_type object\n",
"dtype: object"
]
},
"metadata": {},
"execution_count": 56
}
]
},
{
"metadata": {
"trusted": true,
"id": "Ox7Yu6i2FRvR",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 233
},
"outputId": "8890dc1d-3212-408f-a7e3-8cc96d08990f"
},
"cell_type": "code",
"source": [
"from wordcloud import WordCloud, STOPWORDS\n",
"stopwords = set(STOPWORDS)\n",
"stopwords.add(\"RT\")\n",
"\n",
"print(type(STOPWORDS))\n",
"\n",
"import random\n",
"\n",
"def random_color_func(word=None, font_size=None, position=None, orientation=None, font_path=None, random_state=None):\n",
" h = 344\n",
" s = int(100.0 * 255.0 / 255.0)\n",
" l = int(100.0 * float(random_state.randint(60, 120)) / 255.0)\n",
" return \"hsl({}, {}%, {}%)\".format(h, s, l)\n",
"\n",
"wordcloud = WordCloud(\n",
" background_color='white',\n",
" stopwords=stopwords,\n",
" max_words=200,\n",
" max_font_size=60, \n",
" random_state=42\n",
" ).generate(str(df.loc[df[\"category\"]==\"offensive_language\"].text))\n",
"print(wordcloud)\n",
"fig = plt.figure(1)\n",
"plt.imshow(wordcloud.recolor(color_func= random_color_func, random_state=3),\n",
" interpolation=\"bilinear\")\n",
"plt.axis('off')\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true,
"id": "AgmUhmCRFRvR",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 215
},
"outputId": "05118b87-4197-4e43-db84-888c344a4a0a"
},
"cell_type": "code",
"source": [
"\n",
"\n",
"def random_color_func(word=None, font_size=None, position=None, orientation=None, font_path=None, random_state=None):\n",
" h = 20\n",
" s = int(100.0 * 255.0 / 255.0)\n",
" l = int(100.0 * float(random_state.randint(60, 120)) / 255.0)\n",
" return \"hsl({}, {}%, {}%)\".format(h, s, l)\n",
"\n",
"wordcloud = WordCloud(\n",
" background_color='white',\n",
" stopwords=stopwords,\n",
" max_words=200,\n",
" max_font_size=60, \n",
" random_state=42\n",
" ).generate(str((df.loc[df[\"category\"]==\"neither\"].text)))\n",
"print(wordcloud)\n",
"fig = plt.figure(1)\n",
"plt.imshow(wordcloud.recolor(color_func= random_color_func, random_state=3),\n",
" interpolation=\"bilinear\")\n",
"plt.axis('off')\n",
"plt.show()\n",
"\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true,
"id": "-4qHG8GIFRvS",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 215
},
"outputId": "a35d2749-018e-4bd0-8f03-0785c3ecab29"
},
"cell_type": "code",
"source": [
"stopwords.add(\"Name\")\n",
"\n",
"def random_color_func(word=None, font_size=None, position=None, orientation=None, font_path=None, random_state=None):\n",
" h = 180\n",
" s = int(100.0 * 255.0 / 255.0)\n",
" l = int(100.0 * float(random_state.randint(60, 120)) / 255.0)\n",
" return \"hsl({}, {}%, {}%)\".format(h, s, l)\n",
"\n",
"wordcloud = WordCloud(\n",
" background_color='white',\n",
" stopwords=stopwords,\n",
" max_words=200,\n",
" max_font_size=60, \n",
" random_state=42\n",
" ).generate(str((df.loc[df[\"category\"]==\"hate_speech\"].text)))\n",
"print(wordcloud)\n",
"fig = plt.figure(1)\n",
"plt.imshow(wordcloud.recolor(color_func= random_color_func, random_state=3),\n",
" interpolation=\"bilinear\")\n",
"plt.axis('off')\n",
"plt.show()\n",
"\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"id": "CgC-Jdi5FRvS"
},
"cell_type": "markdown",
"source": [
"## Build TensorFlow input \n",
"[Reference](https://www.tensorflow.org/guide/data)"
]
},
{
"metadata": {
"trusted": true,
"id": "7tFK36cYFRvS"
},
"cell_type": "code",
"source": [
"train_ds = tf.data.Dataset.from_tensor_slices((df_train.text.values, df_train.label.values))\n",
"val_ds = tf.data.Dataset.from_tensor_slices((df_val.text.values, df_val.label.values))\n",
"test_ds = tf.data.Dataset.from_tensor_slices((df_test.text.values, df_test.label.values))"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "hb96Y5daFRvS",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b226eacd-76f1-4a28-cd8f-41066de708b2"
},
"cell_type": "code",
"source": [
"train_ds"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 61
}
]
},
{
"metadata": {
"id": "jUBqYoupFRvS"
},
"cell_type": "markdown",
"source": [
"While tf.data tries to propagate shape information, the default settings of Dataset.batch result in an unknown batch size because the last batch may not be full. Note the Nones in the shape:\n",
"\n",
"batched_dataset\n",
"```\n",
"\n",
"```\n",
"Use the drop_remainder argument to ignore that last batch, and get full shape propagation:"
]
},
{
"metadata": {
"trusted": true,
"id": "c3cO60xrFRvS",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b96cd584-30ce-43e8-db7a-b3846d422a53"
},
"cell_type": "code",
"source": [
"train_ds = train_ds.shuffle(len(df_train)).batch(32, drop_remainder=False)\n",
"train_ds"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 62
}
]
},
{
"metadata": {
"trusted": true,
"id": "2QZFSyghFRvT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5ca268a5-f192-47fb-8f4f-5b88b492c249"
},
"cell_type": "code",
"source": [
"val_ds = val_ds.shuffle(len(df_val)).batch(32, drop_remainder=False)\n",
"val_ds"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 63
}
]
},
{
"metadata": {
"trusted": true,
"id": "XgJHu9mvFRvT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9fa30534-98f4-4b12-825e-aaf77ac6784f"
},
"cell_type": "code",
"source": [
"test_ds = test_ds.shuffle(len(df_test)).batch(32, drop_remainder=False)\n",
"test_ds"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 64
}
]
},
{
"metadata": {
"id": "1U3-aDTrFRvT"
},
"cell_type": "markdown",
"source": [
"# Printing some Tweets"
]
},
{
"metadata": {
"trusted": true,
"id": "uzEoGsXgFRvT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "0488b9f5-f767-4d79-a013-50421a01292b"
},
"cell_type": "code",
"source": [
"for feat, targ in train_ds.take(1):\n",
"    print('Features: {}, Target: {}'.format(feat, targ))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Features: [b\"Chuck E Cheese isn't the same as when I was a kid. Chuck a skinny bitch & they took away the ball pit. Man wtf\"\n",
" b'All I ask for is love n loyalty. . nothing else. No arguing over other bitches/niggas , no he say she say , no social media bs... Just love'\n",
" b'RT @IcySoleOnline: Should of been a cork sole RT @Johnny_Blaz3: why is there no cork on the WTL XI?'\n",
" b\"I can't argue with a bitch over a nigga...\"\n",
" b'@ArianaGrande this bitch @cumwiththewind'\n",
" b'Any where the $$$ at is where the pussy go...'\n",
" b'@_chrisbrownn14 fuck bitches get money'\n",
" b'RT @RuNeshaShamia: Drunk inlove with my gun bae bitch Beyoncé 😈😏'\n",
" b'#porn,#android,#iphone,#ipad,#sex,#xxx, | #CloseUp | cumshot on beauty pussy closeup http://t.co/UonsdrAuV0'\n",
" b'Photo: Mustard on the beat hoe #wdywt http://t.co/DKYqcnRuNO'\n",
" b\"@daggerbyte @MichaelSmartGuy @spacej_me @pizzahut Given the rickety slant of that roof I'd say more like Pizza Shanty\"\n",
" b\"Don't throw me the pussy! Let me come get it\"\n",
" b'“@_getMEAUXmoney: Nobody cares about KIRK Bosley no”I care bout my cousin hoe'\n",
" b'RT @_MindAtEase: Fuck twerking bitch can you cook'\n",
" b'Reminiscin on my swinger days, when I drove a Caddy and my bitch sported a finger wave'\n",
" b'You just salty you ugly and white RT @JayWiz614: @ItsNotHarold they all some hoes'\n",
" b'Bull color schemes?\\n\\nDULL, you stupid bitch. DUKE WEARS BOLD PASTEL COLORS.\\n\\nBull makes no sense.'\n",
" b'She wanna talk shit.. but then have a friendship .. fuck you bitch and I hope u offended'\n",
" b'“@2hood2bgud: @uce_INA 👀👂 a bitch is all ears...” My mama told me not to talk to bitches so...😴😂jk'\n",
" b\"' I will never date a female with another niggah name on her lol . Wtf .\"\n",
" b'RT @Mike_daniels_YG: My last bitch was 18...bitch came in da trap tryna Vine da Dope fuh likes...her funeral was 3 days after dat'\n",
" b'Every gook in #LosAngeles should be deported or killed.'\n",
" b\"RT @prime13_time: Words from jac.. You can't save every hoe.. You got to let them hurt.\"\n",
" b'If my bitch wear a bacon bra....im taking that ass to funky town'\n",
" b'RT @Maxicat: Charlie Rangel Re-Writes History: On GOP “They Think They Won The Civil War” http://t.co/moCeUBRUTf'\n",
" b\"I ain't cutting my hair for u hoes at hc\"\n",
" b'Breakfast fried chicken jerk chicken Tater tots white rice nd press yellow rice nd beans Mac nd cheese http://t.co/Usz8gJnZl0'\n",
" b'RT @StevStiffler: If her bio says \"Only God can judge me\" she\\'s a hoe.'\n",
" b\"Dudes proudly volunteer to be baby daddy #2,3,4 and 5.Whether they'll be a father or not all over some pussy property.\"\n",
" b'Nyomi Banxxx and Skin Diamond the last of the good ebony porn bitches.'\n",
" b'The hoes be rated E !'\n",
" b'Look at this pussy @YatchakHannah http://t.co/cPyO5YP005'], Target: [1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 0 1 1 2 1 2 1 1 1 1 1]\n"
]
}
]
},
{
"metadata": {
"id": "zNx1iVn2FRvT"
},
"cell_type": "markdown",
"source": [
"# Loading models from TensorFlow Hub"
]
},
{
"metadata": {
"trusted": true,
"id": "xDoQeIMIFRvT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2faf8e37-3719-406e-c7f1-538f71e3466d"
},
"cell_type": "code",
"source": [
"bert_model_name = 'small_bert/bert_en_uncased_L-4_H-512_A-8' \n",
"#bert_model_name = 'bert_en_uncased_L-12_H-768_A-12'\n",
"\n",
"map_name_to_handle = {\n",
" 'bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',\n",
" 'bert_en_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',\n",
" 'bert_multi_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',\n",
" 'small_bert/bert_en_uncased_L-2_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',\n",
" 'albert_en_base':\n",
" 'https://tfhub.dev/tensorflow/albert_en_base/2',\n",
" 'electra_small':\n",
" 'https://tfhub.dev/google/electra_small/2',\n",
" 'electra_base':\n",
" 'https://tfhub.dev/google/electra_base/2',\n",
" 'experts_pubmed':\n",
" 'https://tfhub.dev/google/experts/bert/pubmed/2',\n",
" 'experts_wiki_books':\n",
" 'https://tfhub.dev/google/experts/bert/wiki_books/2',\n",
" 'talking-heads_base':\n",
" 'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',\n",
"}\n",
"\n",
"map_model_to_preprocess = {\n",
" 'bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'bert_en_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'bert_multi_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/1',\n",
" 'albert_en_base':\n",
" 'https://tfhub.dev/tensorflow/albert_en_preprocess/1',\n",
" 'electra_small':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'electra_base':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'experts_pubmed':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'experts_wiki_books':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
" 'talking-heads_base':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
"}\n",
"\n",
"tfhub_handle_encoder = map_name_to_handle[bert_model_name]\n",
"tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]\n",
"\n",
"print(f'BERT model selected : {tfhub_handle_encoder}')\n",
"print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"BERT model selected : https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1\n",
"Preprocess model auto-selected: https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1\n"
]
}
]
},
{
"metadata": {
"id": "ZZ5KDC9DFRvU"
},
"cell_type": "markdown",
"source": [
"I've chosen \"small_bert/bert_en_uncased_L-4_H-512_A-8\".\n",
"\n",
"This TF Hub model uses the implementation of BERT from the TensorFlow Models repository on GitHub at tensorflow/models/official/nlp/bert. It uses L=4 hidden layers (i.e., Transformer blocks), a hidden size of H=512, and A=8 attention heads."
]
},
{
"metadata": {
"id": "QxFlWyGSFRvU"
},
"cell_type": "markdown",
"source": [
"### The preprocessing model\n",
"\n",
"Text inputs need to be transformed to numeric token ids and arranged in several Tensors before being input to BERT. TensorFlow Hub provides a matching preprocessing model for each of the BERT models, which implements this transformation using TF ops from the TF.text library. Hence, it is not necessary to run pure Python code outside the TensorFlow model to preprocess text.\n",
"\n",
"The preprocessing model must be the one referenced by the documentation of the BERT model, which can be read at the URL printed above. For the BERT models listed in the dictionaries above, the preprocessing model is selected automatically through `map_model_to_preprocess`."
]
},
{
"metadata": {
"trusted": true,
"id": "qK7tsj0NFRvU"
},
"cell_type": "code",
"source": [
"bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"id": "oHpgchbbFRvU"
},
"cell_type": "markdown",
"source": [
"Let's try the preprocessing model on some text and see the output:"
]
},
{
"metadata": {
"trusted": true,
"id": "xVYTkEf4FRvU",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2c4ca530-4139-4683-cb31-67432c7cea5c"
},
"cell_type": "code",
"source": [
"for text_batch, label_batch in train_ds.take(1):\n",
"    # Look at the first example of the first batch\n",
"    tweet = text_batch.numpy()[0]\n",
"    label = label_batch.numpy()[0]\n",
"    print(f'Tweet: {tweet}')\n",
"    print(f'Label : {label}')\n",
"\n",
"# Run the preprocessing model on that single tweet\n",
"text_test = [tweet]\n",
"text_preprocessed = bert_preprocess_model(text_test)\n",
"\n",
"print(f'Keys : {list(text_preprocessed.keys())}')\n",
"print(f'Shape : {text_preprocessed[\"input_word_ids\"].shape}')\n",
"print(f'Word Ids : {text_preprocessed[\"input_word_ids\"][0, :12]}')\n",
"print(f'Input Mask : {text_preprocessed[\"input_mask\"][0, :12]}')\n",
"print(f'Type Ids : {text_preprocessed[\"input_type_ids\"][0, :12]}')\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Tweet: b'RT @memethagreat_: When you know you dealing with hoe 🙇'\n",
"Label : 1\n",
"Keys : ['input_word_ids', 'input_mask', 'input_type_ids']\n",
"Shape : (1, 128)\n",
"Word Ids : [ 101 19387 1030 2033 11368 3270 17603 4017 1035 1024 2043 2017]\n",
"Input Mask : [1 1 1 1 1 1 1 1 1 1 1 1]\n",
"Type Ids : [0 0 0 0 0 0 0 0 0 0 0 0]\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "rna5wcS1FRvV"
},
"cell_type": "code",
"source": [
"bert_model = hub.KerasLayer(tfhub_handle_encoder)"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "AARRd8FwFRvV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f86d474d-6f1e-4845-f03f-fd035c41b0e8"
},
"cell_type": "code",
"source": [
"bert_results = bert_model(text_preprocessed)\n",
"\n",
"print(f'Loaded BERT: {tfhub_handle_encoder}')\n",
"print(f'Pooled Outputs Shape:{bert_results[\"pooled_output\"].shape}')\n",
"print(f'Pooled Outputs Values:{bert_results[\"pooled_output\"][0, :12]}')\n",
"print(f'Sequence Outputs Shape:{bert_results[\"sequence_output\"].shape}')\n",
"print(f'Sequence Outputs Values:{bert_results[\"sequence_output\"][0, :12]}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loaded BERT: https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1\n",
"Pooled Outputs Shape:(1, 512)\n",
"Pooled Outputs Values:[ 0.70658135 0.9744342 -0.10705773 0.28930727 0.01819332 0.8509703\n",
" 0.9979883 -0.99360496 -0.02548804 -0.997787 0.21190482 -0.6805903 ]\n",
"Sequence Outputs Shape:(1, 128, 512)\n",
"Sequence Outputs Values:[[ 0.6834601 0.48724097 0.6317546 ... 0.6032436 0.6348365\n",
" 0.76171243]\n",
" [-0.07604086 0.2479679 0.06419381 ... 0.2410545 0.9862181\n",
" 0.6980718 ]\n",
" [-0.35939074 0.75332546 0.3006973 ... -0.21788383 -0.40428442\n",
" 1.4083261 ]\n",
" ...\n",
" [ 0.3959934 0.50379336 1.2242035 ... -0.22907454 -0.22878148\n",
" -0.3488557 ]\n",
" [ 0.5115414 -0.6881789 -0.01754081 ... -1.9684261 -0.19548538\n",
" 0.25261897]\n",
" [ 0.5218599 0.8258102 -0.24864265 ... -0.7810521 -0.39489534\n",
" 0.7557944 ]]\n"
]
}
]
},
{
"metadata": {
"id": "LhKlsuQoFRvV"
},
"cell_type": "markdown",
"source": [
"# Techniques to deal with unbalanced data"
]
},
{
"metadata": {
"id": "R9R4RbFkFRvV"
},
"cell_type": "markdown",
"source": [
"### Calculate class weights\n",
"\n",
"One of the goals is to identify hate speech, but there are few such samples to work with, so the classifier should weight the available examples heavily. I am going to do this by passing a weight for each class to Keras through the `class_weight` argument of `fit`. These weights cause the model to \"pay more attention\" to examples from an under-represented class."
]
},
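{
"metadata": {},
"cell_type": "markdown",
"source": [
"The weighting used here is `total / (n_classes * count)` for each class. A minimal sketch of that formula follows; the function name `balanced_class_weights` and the counts are hypothetical and chosen only for illustration, they are not the counts of this dataset.\n",
"\n",
"```python\n",
"def balanced_class_weights(counts):\n",
"    # counts maps class id -> number of training examples of that class\n",
"    total = sum(counts.values())\n",
"    n_classes = len(counts)\n",
"    return {c: total / (n_classes * n) for c, n in counts.items()}\n",
"\n",
"# Hypothetical counts, for illustration only\n",
"example_counts = {0: 100, 1: 700, 2: 200}\n",
"weights = balanced_class_weights(example_counts)\n",
"for c in sorted(weights):\n",
"    weighted = weights[c] * example_counts[c]\n",
"    print(f'class {c}: weight {weights[c]:.2f}, weighted count {weighted:.1f}')\n",
"```\n",
"\n",
"After weighting, every class contributes the same weighted count (`total / n_classes`), so the rare class is no longer drowned out in the loss."
]
},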
{
"metadata": {
"trusted": true,
"id": "ChzMuXRhFRvV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "426d049e-322a-46f4-fa38-d7ef66b8c7af"
},
"cell_type": "code",
"source": [
"weight_for_0 = (1 / hate)*(total)/3.0 \n",
"weight_for_1 = (1 / ofensive)*(total)/3.0\n",
"weight_for_2 = (1 / neither)*(total)/3.0\n",
"\n",
"\n",
"class_weight = {0: weight_for_0, 1: weight_for_1, 2: weight_for_2}\n",
"\n",
"print('Weight for class 0: {:.2f}'.format(weight_for_0))\n",
"print('Weight for class 1: {:.2f}'.format(weight_for_1))\n",
"print('Weight for class 2: {:.2f}'.format(weight_for_2))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Weight for class 0: 5.78\n",
"Weight for class 1: 0.43\n",
"Weight for class 2: 1.98\n"
]
}
]
},
{
"metadata": {
"id": "7sEfBfCDFRvV"
},
"cell_type": "markdown",
"source": [
"### Set the correct initial bias\n",
"\n",
"The default initial guesses for the output bias are not great, because the dataset is imbalanced. Set the output layer's bias to reflect that (See: [A Recipe for Training Neural Networks: \"init well\"](http://karpathy.github.io/2019/04/25/recipe/#2-set-up-the-end-to-end-trainingevaluation-skeleton--get-dumb-baselines)). This can help with initial convergence.\n",
"\n",
"With the default bias initialization the model starts out predicting the three classes with roughly equal probability, so the loss should be about -log(1/n_classes) = log(3) ≈ 1.0986."
]
},
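{
"metadata": {},
"cell_type": "markdown",
"source": [
"One common way to pick such a bias is to set it to the log of the class frequencies, so that the softmax output starts at the class priors. The sketch below illustrates this under stated assumptions: the priors are hypothetical, and this is not necessarily how the hard-coded values in the next cell were obtained.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def softmax(logits):\n",
"    # Numerically stable softmax over a list of floats\n",
"    m = max(logits)\n",
"    exps = [math.exp(x - m) for x in logits]\n",
"    s = sum(exps)\n",
"    return [e / s for e in exps]\n",
"\n",
"# Hypothetical class frequencies, for illustration only\n",
"priors = [0.06, 0.77, 0.17]\n",
"\n",
"# Bias = log(prior): with near-zero weights, softmax(bias) reproduces the priors\n",
"bias = [math.log(p) for p in priors]\n",
"start_probs = softmax(bias)\n",
"\n",
"# Expected initial loss for a uniform prediction vs. a prior-based prediction\n",
"uniform_loss = math.log(len(priors))\n",
"prior_loss = -sum(p * math.log(p) for p in priors)\n",
"print([round(p, 2) for p in start_probs])\n",
"print(round(uniform_loss, 4), round(prior_loss, 4))\n",
"```\n",
"\n",
"The prior-based starting loss is the entropy of the class distribution, which is lower than log(3) whenever the classes are imbalanced."
]
},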
{
"metadata": {
"trusted": true,
"id": "-pXwCoxfFRvV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ea66da54-f8d8-4329-c113-8dcd58c74ab4"
},
"cell_type": "code",
"source": [
"#initial_output_bias = np.array([3.938462, 6.535164, 5.])\n",
"initial_output_bias = np.array([3.938462, 15, 5.])\n",
"initial_output_bias "
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([ 3.938462, 15. , 5. ])"
]
},
"metadata": {},
"execution_count": 72
}
]
},
{
"metadata": {
"id": "Ey9BmFmpFRvW"
},
"cell_type": "markdown",
"source": [
"# BERT + MLP\n",
"\n",
"I am going to create a simple fine-tuned model made of the preprocessing model, the selected BERT model, a hidden Dense layer, a Dropout layer, and a final Dense layer with softmax over the three classes."
]
},
{
"metadata": {
"trusted": true,
"id": "iMTAzNKyFRvW"
},
"cell_type": "code",
"source": [
"def build_classifier_model(output_bias=None):\n",
"    # Optionally initialize the bias of the output layer with fixed values\n",
"    if output_bias is not None:\n",
"        output_bias = tf.keras.initializers.Constant(output_bias)\n",
"\n",
"    # Raw strings go in; preprocessing and encoding happen inside the model\n",
"    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
"    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')\n",
"    encoder_inputs = preprocessing_layer(text_input)\n",
"    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')\n",
"    outputs = encoder(encoder_inputs)\n",
"\n",
"    # MLP head on top of the pooled sentence representation\n",
"    net = outputs['pooled_output']\n",
"    net = tf.keras.layers.Dense(512, activation='relu')(net)\n",
"    net = tf.keras.layers.Dropout(0.2)(net)\n",
"    net = tf.keras.layers.Dense(3, activation='softmax', name='classifier', bias_initializer=output_bias)(net)\n",
"\n",
"    return tf.keras.Model(text_input, net)"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "fMXFr9ivFRvW"
},
"cell_type": "code",
"source": [
"classifier_model = build_classifier_model(output_bias=initial_output_bias)"
],
"execution_count": null,
"outputs": []
},
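{
"cell_type": "code",
"source": [
"# The classifier already ends in a softmax layer, so its output is a vector of\n",
"# class probabilities and no extra sigmoid is needed\n",
"bert_raw_result = classifier_model(tf.constant(text_test))\n",
"print(bert_raw_result)"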
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "E8rYeeX-DvMU",
"outputId": "29ae9a1e-85cc-43b0-baac-e2198bcc1f43"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tf.Tensor([[0.50000215 0.731055 0.5000024 ]], shape=(1, 3), dtype=float32)\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "ABU563dqFRvW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3183bc5b-2a3b-40a7-d8c0-62ea2de5d32a"
},
"cell_type": "code",
"source": [
"classifier_model.get_weights()[-1]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([ 3.938462, 15. , 5. ], dtype=float32)"
]
},
"metadata": {},
"execution_count": 75
}
]
},
{
"metadata": {
"trusted": true,
"id": "Nhvy3UL6FRvW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fc74203b-36c2-4ee3-f639-7d6e1cd5ec96"
},
"cell_type": "code",
"source": [
"classifier_model.summary()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model_2\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" text (InputLayer) [(None,)] 0 [] \n",
" \n",
" preprocessing (KerasLayer) {'input_word_ids': 0 ['text[0][0]'] \n",
" (None, 128), \n",
" 'input_type_ids': \n",
" (None, 128), \n",
" 'input_mask': (Non \n",
" e, 128)} \n",
" \n",
" BERT_encoder (KerasLayer) {'sequence_output': 28763649 ['preprocessing[0][0]', \n",
" (None, 128, 512), 'preprocessing[0][1]', \n",
" 'default': (None, 'preprocessing[0][2]'] \n",
" 512), \n",
" 'pooled_output': ( \n",
" None, 512), \n",
" 'encoder_outputs': \n",
" [(None, 128, 512), \n",
" (None, 128, 512), \n",
" (None, 128, 512), \n",
" (None, 128, 512)]} \n",
" \n",
" dense_2 (Dense) (None, 512) 262656 ['BERT_encoder[0][5]'] \n",
" \n",
" dropout_2 (Dropout) (None, 512) 0 ['dense_2[0][0]'] \n",
" \n",
" classifier (Dense) (None, 3) 1539 ['dropout_2[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 29,027,844\n",
"Trainable params: 29,027,843\n",
"Non-trainable params: 1\n",
"__________________________________________________________________________________________________\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "Zo4kMYEFFRvW",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"outputId": "f8668dec-70cf-4f06-e4fe-a9170871c40d"
},
"cell_type": "code",
"source": [
"tf.keras.utils.plot_model(classifier_model)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 77
}
]
},
{
"metadata": {
"trusted": true,
"id": "lT8Wr1xrFRvW"
},
"cell_type": "code",
"source": [
"# The classifier's last layer applies softmax, so the loss receives probabilities, not logits\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)\n",
"metrics = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "EsNdj_pHFRvW"
},
"cell_type": "code",
"source": [
"epochs = 15\n",
"steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()\n",
"num_train_steps = steps_per_epoch * epochs\n",
"num_warmup_steps = int(0.1*num_train_steps)\n",
"\n",
"init_lr = 3e-5\n",
"optimizer = optimization.create_optimizer(init_lr=init_lr,\n",
" num_train_steps=num_train_steps,\n",
" num_warmup_steps=num_warmup_steps,\n",
" optimizer_type='adamw')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"epochs, steps_per_epoch, num_train_steps, num_warmup_steps"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IADkXHnPMUFe",
"outputId": "c9369505-6940-40ff-f883-c367904c01eb"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(15, 628, 9420, 942)"
]
},
"metadata": {},
"execution_count": 80
}
]
},
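{
"metadata": {},
"cell_type": "markdown",
"source": [
"The optimizer above is configured with 9420 training steps, 942 of which are warmup steps. Below is a minimal sketch of such a schedule, assuming linear warmup up to `init_lr` followed by linear decay to zero. The helper `lr_at` is hypothetical, and the exact formula used by `optimization.create_optimizer` may differ in detail.\n",
"\n",
"```python\n",
"def lr_at(step, init_lr=3e-5, num_train_steps=9420, num_warmup_steps=942):\n",
"    # Simplified sketch: linear warmup from 0 to init_lr, then linear decay to 0\n",
"    if step < num_warmup_steps:\n",
"        return init_lr * step / num_warmup_steps\n",
"    remaining = max(num_train_steps - step, 0)\n",
"    return init_lr * remaining / (num_train_steps - num_warmup_steps)\n",
"\n",
"for step in (0, 471, 942, 5000, 9420):\n",
"    print(step, lr_at(step))\n",
"```\n",
"\n",
"The warmup keeps the first updates small, which tends to protect the pre-trained BERT weights while the freshly initialized head is still producing large gradients."
]
},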
{
"metadata": {
"trusted": true,
"id": "jU6g-j6RFRvW"
},
"cell_type": "code",
"source": [
"classifier_model.compile(optimizer=optimizer,\n",
"                         loss=loss,\n",
"                         metrics=metrics)"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "OPyseUFVFRvX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f7f798a4-22aa-440d-c195-e6a730221a47"
},
"cell_type": "code",
"source": [
"print(f'Training model with {tfhub_handle_encoder}')\n",
"history = classifier_model.fit(x=train_ds,\n",
" validation_data=val_ds,\n",
" epochs=epochs,\n",
" # The class weights go here\n",
" class_weight=class_weight\n",
")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training model with https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1\n",
"Epoch 1/15\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/keras/backend.py:5585: UserWarning: \"`sparse_categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a Softmax activation and thus does not represent logits. Was this intended?\n",
" output, from_logits = _get_logits(\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"628/628 [==============================] - 172s 260ms/step - loss: 1.9218 - accuracy: 0.6873 - val_loss: 0.4483 - val_accuracy: 0.8292\n",
"Epoch 2/15\n",
"628/628 [==============================] - 162s 258ms/step - loss: 0.5943 - accuracy: 0.8236 - val_loss: 0.4338 - val_accuracy: 0.8180\n",
"Epoch 3/15\n",
"628/628 [==============================] - 162s 258ms/step - loss: 0.4613 - accuracy: 0.8532 - val_loss: 0.3169 - val_accuracy: 0.8897\n",
"Epoch 4/15\n",
"628/628 [==============================] - 158s 252ms/step - loss: 0.4038 - accuracy: 0.8870 - val_loss: 0.3273 - val_accuracy: 0.8929\n",
"Epoch 5/15\n",
"628/628 [==============================] - 161s 256ms/step - loss: 0.3218 - accuracy: 0.9049 - val_loss: 0.3407 - val_accuracy: 0.8929\n",
"Epoch 6/15\n",
"628/628 [==============================] - 158s 252ms/step - loss: 0.2585 - accuracy: 0.9289 - val_loss: 0.3880 - val_accuracy: 0.8817\n",
"Epoch 7/15\n",
"628/628 [==============================] - 157s 250ms/step - loss: 0.1978 - accuracy: 0.9468 - val_loss: 0.4062 - val_accuracy: 0.9045\n",
"Epoch 8/15\n",
"628/628 [==============================] - 160s 255ms/step - loss: 0.1618 - accuracy: 0.9621 - val_loss: 0.4794 - val_accuracy: 0.8920\n",
"Epoch 9/15\n",
"628/628 [==============================] - 157s 250ms/step - loss: 0.1261 - accuracy: 0.9721 - val_loss: 0.5110 - val_accuracy: 0.9072\n",
"Epoch 10/15\n",
"628/628 [==============================] - 161s 257ms/step - loss: 0.1078 - accuracy: 0.9778 - val_loss: 0.6166 - val_accuracy: 0.8879\n",
"Epoch 11/15\n",
"628/628 [==============================] - 161s 256ms/step - loss: 0.0832 - accuracy: 0.9830 - val_loss: 0.6481 - val_accuracy: 0.9027\n",
"Epoch 12/15\n",
"628/628 [==============================] - 157s 250ms/step - loss: 0.0553 - accuracy: 0.9885 - val_loss: 0.7068 - val_accuracy: 0.9041\n",
"Epoch 13/15\n",
"628/628 [==============================] - 157s 251ms/step - loss: 0.0600 - accuracy: 0.9896 - val_loss: 0.7108 - val_accuracy: 0.9077\n",
"Epoch 14/15\n",
"628/628 [==============================] - 158s 251ms/step - loss: 0.0371 - accuracy: 0.9929 - val_loss: 0.7470 - val_accuracy: 0.9072\n",
"Epoch 15/15\n",
"628/628 [==============================] - 158s 251ms/step - loss: 0.0368 - accuracy: 0.9927 - val_loss: 0.7540 - val_accuracy: 0.9054\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "iyIkTCxlFRvX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e6e894a1-d161-41fc-81b7-52dfd6d6b7bd"
},
"cell_type": "code",
"source": [
"loss, accuracy = classifier_model.evaluate(test_ds)\n",
"\n",
"print(f'Loss: {loss}')\n",
"print(f'Accuracy: {accuracy}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"78/78 [==============================] - 10s 125ms/step - loss: 0.7863 - accuracy: 0.8975\n",
"Loss: 0.786299467086792\n",
"Accuracy: 0.8975393176078796\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "9QxXXAbFFRvX",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 656
},
"outputId": "5443a9d9-5a4d-493f-8768-5bfc75df1187"
},
"cell_type": "code",
"source": [
"history_dict = history.history\n",
"print(history_dict.keys())\n",
"\n",
"acc = history_dict['accuracy']\n",
"val_acc = history_dict['val_accuracy']\n",
"loss = history_dict['loss']\n",
"val_loss = history_dict['val_loss']\n",
"\n",
"epochs = range(1, len(acc) + 1)\n",
"fig = plt.figure(figsize=(12, 10))\n",
"\n",
"plt.subplot(2, 1, 1)\n",
"# 'r' is a solid red line (training), 'b' is a solid blue line (validation)\n",
"plt.plot(epochs, loss, 'r', label='Training loss')\n",
"plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
"plt.title('Training and validation loss')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.subplot(2, 1, 2)\n",
"plt.plot(epochs, acc, 'r', label='Training acc')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
"plt.title('Training and validation accuracy')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Accuracy')\n",
"# Adjust spacing once both subplots exist\n",
"fig.tight_layout()\n",
"plt.legend(loc='lower right')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 50
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"id": "ZlkpD3IwFRvX"
},
"cell_type": "markdown",
"source": [
"## Export for inference\n",
"\n",
"Now you just save your fine-tuned model for later use."
]
},
{
"metadata": {
"trusted": true,
"id": "z4F1ZgkyFRvX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b0ba59bb-add5-4657-efaf-1e8e6eb548d6"
},
"cell_type": "code",
"source": [
"dataset_name = 'mpl_hate_speech'\n",
"saved_model_path = './{}_bert'.format(dataset_name.replace('/', '_'))\n",
"\n",
"classifier_model.save(saved_model_path, include_optimizer=False)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 124). These functions will not be directly callable after loading.\n"
]
}
]
},
{
"metadata": {
"id": "CO3dJu3oFRvX"
},
"cell_type": "markdown",
"source": [
"# Results for MLP"
]
},
{
"metadata": {
"trusted": true,
"id": "qlHgotskFRvX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1913e081-e099-4716-c477-bf7293b6d0bf"
},
"cell_type": "code",
"source": [
"result = classifier_model.predict(test_ds)\n",
"print(result.shape)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"78/78 [==============================] - 12s 141ms/step\n",
"(2479, 3)\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "5HuQJ80PFRvX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9b30f604-919c-4011-ad3a-168cc4ca8853"
},
"cell_type": "code",
"source": [
"result[0:2]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0.00000052, 0.99999785, 0.00000165],\n",
" [0.00000052, 0.99998534, 0.00001417]], dtype=float32)"
]
},
"metadata": {},
"execution_count": 53
}
]
},
{
"metadata": {
"trusted": true,
"id": "Pq6d4j90FRvX"
},
"cell_type": "code",
"source": [
"classes = np.argmax(result, axis=-1)"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"id": "AnDOh5iVFRvY"
},
"cell_type": "markdown",
"source": [
"### Doing predictions and saving to np.array"
]
},
{
"metadata": {
"trusted": true,
"id": "v3CveqdJFRvY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a798845d-e902-424d-a803-8ec8678e5866"
},
"cell_type": "code",
"source": [
"tweet = []\n",
"test_labels = []\n",
"predictions = []\n",
"for tweet, labels in test_ds.take(-1):\n",
" tweet = tweet.numpy()\n",
" test_labels.append(labels.numpy())\n",
" predictions.append(classifier_model.predict(tweet))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 1s 771ms/step\n",
"1/1 [==============================] - 0s 84ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 83ms/step\n",
"1/1 [==============================] - 0s 80ms/step\n",
"1/1 [==============================] - 0s 88ms/step\n",
"1/1 [==============================] - 0s 87ms/step\n",
"1/1 [==============================] - 0s 77ms/step\n",
"1/1 [==============================] - 0s 80ms/step\n",
"1/1 [==============================] - 0s 75ms/step\n",
"1/1 [==============================] - 0s 82ms/step\n",
"1/1 [==============================] - 0s 75ms/step\n",
"1/1 [==============================] - 0s 69ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 72ms/step\n",
"1/1 [==============================] - 0s 79ms/step\n",
"1/1 [==============================] - 0s 87ms/step\n",
"1/1 [==============================] - 0s 77ms/step\n",
"1/1 [==============================] - 0s 88ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 136ms/step\n",
"1/1 [==============================] - 0s 136ms/step\n",
"1/1 [==============================] - 0s 138ms/step\n",
"1/1 [==============================] - 0s 143ms/step\n",
"1/1 [==============================] - 0s 141ms/step\n",
"1/1 [==============================] - 0s 136ms/step\n",
"1/1 [==============================] - 0s 139ms/step\n",
"1/1 [==============================] - 0s 140ms/step\n",
"1/1 [==============================] - 0s 160ms/step\n",
"1/1 [==============================] - 0s 155ms/step\n",
"1/1 [==============================] - 0s 150ms/step\n",
"1/1 [==============================] - 0s 146ms/step\n",
"1/1 [==============================] - 0s 151ms/step\n",
"1/1 [==============================] - 0s 93ms/step\n",
"1/1 [==============================] - 0s 80ms/step\n",
"1/1 [==============================] - 0s 81ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 81ms/step\n",
"1/1 [==============================] - 0s 89ms/step\n",
"1/1 [==============================] - 0s 77ms/step\n",
"1/1 [==============================] - 0s 78ms/step\n",
"1/1 [==============================] - 0s 80ms/step\n",
"1/1 [==============================] - 0s 90ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 82ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 78ms/step\n",
"1/1 [==============================] - 0s 87ms/step\n",
"1/1 [==============================] - 0s 75ms/step\n",
"1/1 [==============================] - 0s 74ms/step\n",
"1/1 [==============================] - 0s 87ms/step\n",
"1/1 [==============================] - 0s 84ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 83ms/step\n",
"1/1 [==============================] - 0s 77ms/step\n",
"1/1 [==============================] - 0s 80ms/step\n",
"1/1 [==============================] - 0s 73ms/step\n",
"1/1 [==============================] - 0s 82ms/step\n",
"1/1 [==============================] - 0s 74ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 75ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 122ms/step\n",
"1/1 [==============================] - 0s 79ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 77ms/step\n",
"1/1 [==============================] - 0s 74ms/step\n",
"1/1 [==============================] - 0s 72ms/step\n",
"1/1 [==============================] - 0s 84ms/step\n",
"1/1 [==============================] - 0s 80ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 76ms/step\n",
"1/1 [==============================] - 0s 75ms/step\n",
"1/1 [==============================] - 0s 84ms/step\n",
"1/1 [==============================] - 1s 722ms/step\n"
]
}
]
},
{
"metadata": {
"trusted": true,
"id": "tPtBYe6BFRvY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4e947bd5-0087-4e8b-929c-3bc1b3cac4b6"
},
"cell_type": "code",
"source": [
"test_labels[0:2]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[array([2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
" array([1, 1, 1, 1, 2, 1, 1, 0, 1, 1, 1, 1, 2, 0, 1, 1, 2, 1, 1, 1, 2, 1,\n",
" 1, 1, 1, 0, 2, 1, 1, 2, 1, 1])]"
]
},
"metadata": {},
"execution_count": 56
}
]
},
{
"metadata": {
"trusted": true,
"id": "4JjivM3sFRvY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "6afe955d-7767-4fe8-b9bc-e9e07ca7a16f"
},
"cell_type": "code",
"source": [
"predictions[0:2]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[array([[0.00000006, 0.00000154, 0.99999845],\n",
" [0.00066261, 0.999332 , 0.00000539],\n",
" [0.00000088, 0.99999774, 0.00000145],\n",
" [0.00000096, 0.99999714, 0.00000188],\n",
" [0.0000005 , 0.99999845, 0.00000113],\n",
" [0.0000014 , 0.9999963 , 0.00000224],\n",
" [0.00000044, 0.9999976 , 0.00000186],\n",
" [0.0000008 , 0.99999714, 0.00000207],\n",
" [0.0026225 , 0.99736387, 0.00001364],\n",
" [0.00000101, 0.9999958 , 0.00000318],\n",
" [0.00032014, 0.9996766 , 0.00000333],\n",
" [0.00000168, 0.999997 , 0.0000013 ],\n",
" [0.00000011, 0.00000586, 0.99999404],\n",
" [0.9999267 , 0.00007198, 0.00000127],\n",
" [0.00000123, 0.99999475, 0.00000408],\n",
" [0.00000055, 0.99999774, 0.00000165],\n",
" [0.00121367, 0.9987759 , 0.00001046],\n",
" [0.00000089, 0.99999726, 0.00000187],\n",
" [0.00000229, 0.9999949 , 0.00000284],\n",
" [0.00000595, 0.9999914 , 0.00000268],\n",
" [0.00000382, 0.99999464, 0.00000154],\n",
" [0.00000107, 0.99999404, 0.00000488],\n",
" [0.0001139 , 0.9998839 , 0.00000215],\n",
" [0.00000049, 0.99999833, 0.00000116],\n",
" [0.99995697, 0.00004238, 0.0000006 ],\n",
" [0.00000162, 0.999995 , 0.00000328],\n",
" [0.00000084, 0.999995 , 0.00000414],\n",
" [0.00000209, 0.9999807 , 0.00001732],\n",
" [0.00000096, 0.99999726, 0.00000173],\n",
" [0.00000543, 0.99999213, 0.00000235],\n",
" [0.00000075, 0.9999963 , 0.00000301],\n",
" [0.00001862, 0.9999739 , 0.00000749]], dtype=float32),\n",
" array([[0.00000042, 0.999998 , 0.00000155],\n",
" [0.00000426, 0.999992 , 0.0000037 ],\n",
" [0.00109527, 0.99889886, 0.00000587],\n",
" [0.00000035, 0.9999975 , 0.00000219],\n",
" [0.00000041, 0.00005752, 0.99994206],\n",
" [0.98130816, 0.01868972, 0.00000209],\n",
" [0.0293135 , 0.9698654 , 0.00082111],\n",
" [0.00062872, 0.99937004, 0.00000125],\n",
" [0.00000103, 0.9999974 , 0.00000156],\n",
" [0.00000042, 0.99999774, 0.0000019 ],\n",
" [0.00000089, 0.999998 , 0.00000118],\n",
" [0.00000068, 0.999998 , 0.00000129],\n",
" [0.00000168, 0.00001529, 0.9999831 ],\n",
" [0.00292425, 0.99706656, 0.00000921],\n",
" [0.00000192, 0.9999969 , 0.00000116],\n",
" [0.9542036 , 0.04579007, 0.0000064 ],\n",
" [0.00000013, 0.00000313, 0.9999968 ],\n",
" [0.00000696, 0.9999907 , 0.00000241],\n",
" [0.00000062, 0.9999974 , 0.00000208],\n",
" [0.00002774, 0.99997103, 0.00000118],\n",
" [0.00063122, 0.00601086, 0.99335796],\n",
" [0.00000099, 0.9999975 , 0.00000158],\n",
" [0.00000068, 0.9999926 , 0.00000664],\n",
" [0.00000638, 0.9958839 , 0.00410963],\n",
" [0.0000165 , 0.9999771 , 0.00000648],\n",
" [0.99997306, 0.00002616, 0.00000084],\n",
" [0.00000027, 0.00000786, 0.9999919 ],\n",
" [0.00000057, 0.99999666, 0.00000272],\n",
" [0.00000023, 0.9999944 , 0.00000539],\n",
" [0.00000068, 0.00015671, 0.99984264],\n",
" [0.0000865 , 0.9999057 , 0.00000774],\n",
" [0.00000049, 0.99999785, 0.00000163]], dtype=float32)]"
]
},
"metadata": {},
"execution_count": 57
}
]
},
{
"metadata": {
"trusted": true,
"id": "MHWLNEejFRvY"
},
"cell_type": "code",
"source": [
"from itertools import chain\n",
"flatten_list = list(chain.from_iterable(predictions))\n",
"y_pred = np.argmax(flatten_list, axis=-1)"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true,
"id": "WUzvFYplFRvY"
},
"cell_type": "code",
"source": [
"y_test = np.array(list(chain.from_iterable(test_labels)))"
],
"execution_count": null,
"outputs": []
},
{
"metadata": {
"id": "OmMXdwFlFRvY"
},
"cell_type": "markdown",
"source": [
"# Confusion Matrix MLP"
]
},
{
"metadata": {
"trusted": true,
"id": "o0zYKNYvFRvY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "78f0a33d-93a4-4ed3-ed7c-fbb41b3dc4e5"
},
"cell_type": "code",
"source": [
"from sklearn.metrics import confusion_matrix\n",
"confusion_matrix(y_test, y_pred)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 50, 80, 13],\n",
" [ 71, 1817, 32],\n",
" [ 5, 53, 358]])"
]
},
"metadata": {},
"execution_count": 60
}
]
},
{
"cell_type": "markdown",
"source": [
"Save model"
],
"metadata": {
"id": "8g9WssqUBkuQ"
}
},
{
"cell_type": "code",
"source": [
"classifier_model.save('/content/drive/MyDrive/AI/hate_speech/classifier_model.h5', save_format='h5')"
],
"metadata": {
"id": "AKohkdFlOq8g"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# ## load manual with weight\n",
"# bert_model_name = 'small_bert/bert_en_uncased_L-4_H-512_A-8'\n",
"# map_name_to_handle = {\n",
"# 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
"# 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',\n",
"# }\n",
"\n",
"# map_model_to_preprocess = {\n",
"# 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
"# 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1',\n",
"# }\n",
"\n",
"# tfhub_handle_encoder = map_name_to_handle[bert_model_name]\n",
"# tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]\n",
"\n",
"\n",
"# initial_output_bias = np.array([3.938462, 15, 5.])\n",
"\n",
"\n",
"# def build_classifier_model(output_bias=None):\n",
"# if output_bias is not None:\n",
"# output_bias = tf.keras.initializers.Constant(output_bias)\n",
" \n",
"# text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
"# preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')\n",
"# encoder_inputs = preprocessing_layer(text_input)\n",
"# encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')\n",
"# outputs = encoder(encoder_inputs)\n",
"# net = outputs['pooled_output']\n",
"# net = tf.keras.layers.Dense(512, activation=\"relu\")(net)\n",
"# net = tf.keras.layers.Dropout(0.2)(net)\n",
"# net = tf.keras.layers.Dense(3, activation=\"softmax\", name='classifier', bias_initializer=output_bias)(net)\n",
"# return tf.keras.Model(text_input, net)\n",
"\n",
"\n",
"# ## compile\n",
"# epochs = 15\n",
"# steps_per_epoch = 628\n",
"# num_train_steps = steps_per_epoch * epochs\n",
"# num_warmup_steps = int(0.1*num_train_steps)\n",
"# init_lr = 3e-5\n",
"\n",
"# optimizer = optimization.create_optimizer(init_lr=init_lr,\n",
"# num_train_steps=num_train_steps,\n",
"# num_warmup_steps=num_warmup_steps,\n",
"# optimizer_type='adamw')\n",
"# loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"# metrics = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
"\n",
"\n",
"# classifier_model = build_classifier_model(output_bias=initial_output_bias)\n",
"\n",
"# classifier_model.compile(optimizer=optimizer,\n",
"# loss=loss,\n",
"# metrics=metrics)\n",
"\n",
"# classifier_model.load_weights(checkpoint_path)\n"
],
"metadata": {
"id": "Ugop_mNNtLL8"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"colab": {
"provenance": []
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}