{ "cells": [ { "cell_type": "markdown", "source": [ "# Dataset processing\n", "\n", "This notebook processes the raw csv outputs from VAERS into Huggingface datasets. It shouldn't generally need to be run by the end user. " ], "metadata": { "collapsed": false }, "id": "35523bbeb2e03eae" }, { "cell_type": "code", "outputs": [], "source": [ "import pandas as pd\n", "import datasets\n", "import glob\n", "import tqdm.notebook as tqdm\n", "from sklearn.model_selection import train_test_split\n", "from typing import Tuple\n", "from datetime import datetime\n", "\n", "pd.set_option('future.no_silent_downcasting', True)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:38.481853Z", "start_time": "2024-01-27T22:28:38.458294Z" } }, "id": "9362802d64424442", "execution_count": 15 }, { "cell_type": "code", "outputs": [], "source": [ "HF_URL: str = \"chrisvoncsefalvay/vaers-outcomes\"\n", "\n", "FLAG_COLUMNS: list = [\"DIED\", \"ER_VISIT\", \"HOSPITAL\", \"OFC_VISIT\", \"X_STAY\", \"DISABLE\"]\n", "DEMOGRAPHIC_COLUMNS: list = [\"AGE_YRS\", \"SEX\"]\n", "DERIVED_COLUMNS: list = [\"D_PRESENTED\"]\n", "ID_COLUMNS: list = [\"VAERS_ID\"]\n", "TEXT_COLUMNS: list = [\"SYMPTOM_TEXT\"]\n", "\n", "TEST_TRAIN_FRACTION: float = 0.3\n", "TRAIN_VAL_FRACTION: float = 0.5" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:38.498974Z", "start_time": "2024-01-27T22:28:38.486237Z" } }, "id": "34b77edf5a1fce96", "execution_count": 16 }, { "cell_type": "markdown", "source": [ "## Reading data files" ], "metadata": { "collapsed": false }, "id": "f5f84ddd06e9313e" }, { "cell_type": "code", "outputs": [], "source": [ "def read_aggregate(pattern: str) -> pd.DataFrame:\n", " files = glob.glob(f\"../data/{pattern}\")\n", " dfs = []\n", " for file in tqdm.tqdm(files):\n", " dfs.append(pd.read_csv(file, encoding=\"latin-1\", low_memory=False))\n", "\n", " res = pd.concat(dfs, ignore_index=True)\n", " \n", " print(f\"Processed {len(dfs)} files for a total of {len(res)} records.\")\n", " \n", " return res" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:38.508227Z", "start_time": "2024-01-27T22:28:38.500697Z" } }, "id": "a7772ed4b4b51868", "execution_count": 17 }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": " 0%| | 0/1 [00:00 pd.DataFrame:\n", " for column in FLAG_COLUMNS + [\"ER_ED_VISIT\"]:\n", " df[column] = df[column].replace(\"Y\", 1).fillna(0).astype(int)\n", " \n", " df['ER_VISIT'] = df[['ER_VISIT', 'ER_ED_VISIT']].max(axis=1)\n", " \n", " df = df.drop(columns=['ER_ED_VISIT'])\n", " \n", " df = df.dropna(subset=['SYMPTOM_TEXT'])\n", " \n", " return df" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.603731Z", "start_time": "2024-01-27T22:28:39.590617Z" } }, "id": "9aad00c9fe40adb8", "execution_count": 20 }, { "cell_type": "code", "outputs": [], "source": [], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.604024Z", "start_time": "2024-01-27T22:28:39.593891Z" } }, "id": "b0fdcab6ee807404", "execution_count": 20 }, { "cell_type": "code", "outputs": [], "source": [ "data = recode(data)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.665777Z", "start_time": "2024-01-27T22:28:39.597946Z" } }, "id": "f23ee0eae1b70387", "execution_count": 21 }, { "cell_type": "markdown", "source": [ "## Derived fields\n", "\n", "We create the derived field `D_PRESENTED`. This is to provide a shorthand for patients who present in any way: ER, hospitalisation, office visit. It also comprises patients whose hospital stay is extended (`X_STAY`) as this is typically the consequence of presenting." ], "metadata": { "collapsed": false }, "id": "1c2f6b4fc2ae630b" }, { "cell_type": "code", "outputs": [], "source": [ "data['D_PRESENTED'] = data[['ER_VISIT', 'HOSPITAL', 'OFC_VISIT', 'X_STAY']].max(axis=1)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.679534Z", "start_time": "2024-01-27T22:28:39.667363Z" } }, "id": "678847c70756695e", "execution_count": 22 }, { "cell_type": "markdown", "source": [ "## Test/train/validate split\n", "\n", "We do a stratified split by age quintile and gender into test, train and validate sets." ], "metadata": { "collapsed": false }, "id": "dae902b111c8ef3c" }, { "cell_type": "code", "outputs": [], "source": [ "def stratified_split(df: pd.DataFrame, test_train_fraction: float, train_val_fraction: float, random_state: int = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n", " df['AGE_QUINTILE'] = pd.qcut(df['AGE_YRS'], 5, labels = False)\n", " df['STRATIFICATION_VARIABLE'] = df['SEX'].astype(str) + \"_\" + df['AGE_QUINTILE'].astype(str)\n", " df = df.drop(columns=['AGE_QUINTILE'])\n", " \n", " _, train = train_test_split(df, train_size=test_train_fraction, random_state=random_state, stratify=df.STRATIFICATION_VARIABLE)\n", " \n", " val, test = train_test_split(_, train_size=train_val_fraction, random_state=random_state, stratify=_.STRATIFICATION_VARIABLE)\n", " \n", " train = train.drop(columns=\"STRATIFICATION_VARIABLE\")\n", " val = val.drop(columns=\"STRATIFICATION_VARIABLE\")\n", " test = test.drop(columns=\"STRATIFICATION_VARIABLE\") \n", " \n", " return train, test, val" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.680497Z", "start_time": "2024-01-27T22:28:39.678055Z" } }, "id": "ddee47653c94ff02", "execution_count": 23 }, { "cell_type": "code", "outputs": [], "source": [ "train, test, val = stratified_split(data, TEST_TRAIN_FRACTION, TRAIN_VAL_FRACTION)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.863489Z", "start_time": "2024-01-27T22:28:39.680464Z" } }, "id": "bb16aaad0127ef7d", "execution_count": 24 }, { "cell_type": "markdown", "source": [ "## Converting to labels" ], "metadata": { "collapsed": false }, "id": "d61bfdc4a2879905" }, { "cell_type": "code", "outputs": [], "source": [ "def convert_to_dataset(df: pd.DataFrame) -> datasets.Dataset:\n", " df = df.loc[:, ID_COLUMNS + TEXT_COLUMNS + FLAG_COLUMNS + DERIVED_COLUMNS]\n", " \n", " # We create the labels – these have to be floats for multilabel classification that uses BCEWithLogitsLoss\n", " df.loc[:, \"labels\"] = df[FLAG_COLUMNS + DERIVED_COLUMNS].values.astype(float).tolist()\n", " \n", " print(f\"Building dataset with the following label order: {' '.join(FLAG_COLUMNS + DERIVED_COLUMNS)}\")\n", " \n", " # We drop the flag columns\n", " df = df.drop(columns=FLAG_COLUMNS).drop(columns=DERIVED_COLUMNS)\n", " \n", " # We rename the remaining columns\n", " df = df.rename(columns={\"SYMPTOM_TEXT\": \"text\", \"VAERS_ID\": \"id\"})\n", " \n", " return datasets.Dataset.from_pandas(df, preserve_index=False)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:39.867392Z", "start_time": "2024-01-27T22:28:39.864829Z" } }, "id": "3d602444d33b7130", "execution_count": 25 }, { "cell_type": "code", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n", "Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n", "Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n" ] } ], "source": [ "ds = datasets.DatasetDict()\n", "ds[\"train\"] = convert_to_dataset(train)\n", "ds[\"test\"] = convert_to_dataset(test)\n", "ds[\"val\"] = convert_to_dataset(val)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-01-27T22:28:40.207548Z", "start_time": "2024-01-27T22:28:39.872665Z" } }, "id": "e7c854a072956ca3", "execution_count": 26 }, { "cell_type": "markdown", "source": [ "## Saving to Huggingface Hub" ], "metadata": { "collapsed": false }, "id": "ec0167c068238f5a" }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00