{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Kunwar-Saaim Coding Challenge for Fatima Fellowship", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "TPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "eBpjBBZc6IvA" }, "source": [ "# Fatima Fellowship Quick Coding Challenge (Pick 1)\n", "\n", "Thank you for applying to the Fatima Fellowship. To help us select the Fellows and assess your ability to do machine learning research, we are asking that you complete a short coding challenge. Please pick **1 of these 5** coding challenges, whichever is most aligned with your interests. \n", "\n", "**Due date: 1 week**\n", "\n", "**How to submit**: Please make a copy of this Colab notebook, add your code and results, and submit your Colab notebook to the submission link below. If you have never used a Colab notebook, [check out this video](https://www.youtube.com/watch?v=i-HnvsehuSw).\n", "\n", "**Submission link**: https://airtable.com/shrXy3QKSsO2yALd3" ] }, { "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "vFNnwRYul8xh" } }, { "cell_type": "markdown", "metadata": { "id": "braBzmRpMe7_" }, "source": [ "# 1. Deep Learning for Vision" ] }, { "cell_type": "markdown", "metadata": { "id": "1IWw-NZf5WfF" }, "source": [ "**Upside down detector**: Train a model to detect if images are upside down\n", "\n", "* Pick a dataset of natural images (we suggest looking at datasets on the [Hugging Face Hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification&sort=downloads))\n", "* Synthetically turn some of the images upside down. 
Create a training and test set.\n", "* Build a neural network (using TensorFlow, PyTorch, or any framework you like)\n", "* Train it to classify image orientation until a reasonable accuracy is reached\n", "* [Upload the model to the Hugging Face Hub](https://huggingface.co/docs/hub/adding-a-model), and add a link to your model below.\n", "* Look at some of the images that were classified incorrectly. Please explain what you might do to improve your model's performance on these images in the future (you do not need to implement these suggestions)\n", "\n", "**Submission instructions**: Please write your code below and include some examples of images that were classified" ] }, { "cell_type": "code", "source": [ "### WRITE YOUR CODE TO TRAIN THE MODEL HERE" ], "metadata": { "id": "K2GJaYBpw91T" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "**Write up**: \n", "* Link to the model on Hugging Face Hub: \n", "* Include some examples of misclassified images. Please explain what you might do to improve your model's performance on these images in the future (you do not need to implement these suggestions)" ], "metadata": { "id": "qSeLed2JxvGI" } }, { "cell_type": "markdown", "metadata": { "id": "sFU9LTOyMiMj" }, "source": [ "# 2. Deep Learning for NLP\n", "\n", "**Fake news classifier**: Train a text classification model to detect fake news articles!\n", "\n", "* Download the dataset here: https://www.kaggle.com/clmentbisaillon/fake-and-real-news-dataset\n", "* Develop an NLP model for classification that uses a pretrained language model\n", "* Fine-tune your model on the dataset, and generate a ROC curve (and report its AUC) for your model on the test set of your choice. \n", "* [Upload the model to the Hugging Face Hub](https://huggingface.co/docs/hub/adding-a-model), and add a link to your model below.\n", "* *Answer the following question*: Look at some of the news articles that were classified incorrectly. 
Please explain what you might do to improve your model's performance on these news articles in the future (you do not need to implement these suggestions)" ] }, { "cell_type": "markdown", "source": [ "## Model on the Hugging Face Hub: https://huggingface.co/kunwwarsaaim/distill-bert-fake-news-detection" ], "metadata": { "id": "uVAL-L6mmaEd" } }, { "cell_type": "code", "source": [ "!pip install transformers" ], "metadata": { "id": "CRKM9SyZsWkl" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import os\n", "import tensorflow as tf\n", "import pandas as pd\n", "from sklearn.utils import shuffle\n", "import transformers\n", "from transformers import AutoTokenizer, TFAutoModelForSequenceClassification" ], "metadata": { "id": "11B2b71RqwOG" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "try:\n", " tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n", " print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n", "except ValueError:\n", " raise RuntimeError('ERROR: Not connected to a TPU runtime; in Colab, select Runtime > Change runtime type > TPU.')\n", "\n", "tf.config.experimental_connect_to_cluster(tpu)\n", "tf.tpu.experimental.initialize_tpu_system(tpu)\n", "tpu_strategy = tf.distribute.TPUStrategy(tpu) # non-experimental API, TF >= 2.3\n", "\n", "\n", "AUTOTUNE = tf.data.AUTOTUNE # TF >= 2.4" ], "metadata": { "id": "OlwYIn9_pgnJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "batch_size=32 * tpu_strategy.num_replicas_in_sync\n", "print('Batch size:', batch_size)" ], "metadata": { "id": "by_5W3761YsW", "outputId": "ce809a60-fc25-4c31-adb2-ba9fa98d77a6", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Batch size: 256\n" ] } ] }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')\n", 
"os.chdir('/content/drive/MyDrive/news_dataset')" ], "metadata": { "id": "QkGuEO-4jnav", "outputId": "2af82882-512a-4b82-e50c-27b6fc3f27fd", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ] }, { "cell_type": "code", "source": [ "### WRITE YOUR CODE TO TRAIN THE MODEL HERE\n", "fake_news = pd.read_csv('Fake.csv')\n", "true_news = pd.read_csv('True.csv') " ], "metadata": { "id": "E90i018KyJH3" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "true_news['class_'] = 0 # 0 -> real news\n", "fake_news['class_'] = 1 # 1 -> fake news\n", "dataset = pd.concat([true_news,fake_news]).reset_index(drop=True)\n", "dataset = shuffle(dataset, random_state=42) # fixed seed so the train/validation split is reproducible\n", "\n", "dataset['data'] = dataset['title']+ ' ' + dataset['text']\n", "dataset.drop(['title','text','subject','date'],axis=1,inplace=True)\n", "dataset.drop_duplicates(inplace=True)" ], "metadata": { "id": "xkmNq2E1kUoq" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "dataset.keys()" ], "metadata": { "id": "40h-FjdKssww", "outputId": "396b3f77-dc2a-4b4c-8239-4e93bd5acfd6", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Index(['class_', 'data'], dtype='object')" ] }, "metadata": {}, "execution_count": 122 } ] }, { "cell_type": "code", "source": [ "data = list(dataset.data)\n", "label = dataset.class_" ], "metadata": { "id": "DZzugpjusirx" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\") # DistilBERT tokenizer (lowercases its input)\n", "inputs = tokenizer(data, padding=True, truncation=True, return_tensors='tf') # pad to the longest article, truncating at the model's 512-token limit" ], "metadata": { "id": "qpGP3JkNs7Dn" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "len(inputs['input_ids']) == 
len(label)" ], "metadata": { "id": "8eeYQv02w_l9", "outputId": "464101d2-979d-4bea-d711-9561d3a19e32", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 125 } ] }, { "cell_type": "code", "source": [ "tf_dataset = tf.data.Dataset.from_tensor_slices((dict(inputs), label)) # tf.data dataset of (features, label) pairs\n", "# Train/validation split: the first 20% of the (already shuffled) rows are held out for validation\n", "val_data_size = int(0.2*len(label))\n", "# drop_remainder=True keeps every batch the same shape on the TPU; the last partial batch is discarded\n", "val_ds = tf_dataset.take(val_data_size).batch(batch_size, drop_remainder=True)\n", "train_ds = tf_dataset.skip(val_data_size).batch(batch_size, drop_remainder=True)\n", "train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)" ], "metadata": { "id": "U9ScngQYvrt3" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "with tpu_strategy.scope():\n", " model = TFAutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n", " model.compile(\n", " optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5, clipnorm=1.),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.metrics.SparseCategoricalAccuracy()],\n", " )\n", " \n", "history=model.fit(train_ds, validation_data=val_ds, epochs=5, verbose=1)\n" ], "metadata": { "id": "dhTNraaNyhCS", "outputId": "c1d121a4-def4-4e54-8d66-2c20569ec523", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertForSequenceClassification: ['activation_13', 'vocab_layer_norm', 'vocab_projector', 'vocab_transform']\n", "- This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier', 'classifier', 'dropout_39']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/5\n", "122/122 [==============================] - 106s 457ms/step - loss: 0.1616 - sparse_categorical_accuracy: 0.9687 - val_loss: 0.0124 - val_sparse_categorical_accuracy: 0.9993\n", "Epoch 2/5\n", "122/122 [==============================] - 50s 414ms/step - loss: 0.0096 - sparse_categorical_accuracy: 0.9991 - val_loss: 0.0042 - val_sparse_categorical_accuracy: 0.9997\n", "Epoch 3/5\n", "122/122 [==============================] - 50s 414ms/step - loss: 0.0039 - sparse_categorical_accuracy: 0.9996 - val_loss: 0.0030 - val_sparse_categorical_accuracy: 0.9996\n", "Epoch 4/5\n", "122/122 [==============================] - 51s 414ms/step - loss: 0.0019 - sparse_categorical_accuracy: 0.9998 - val_loss: 0.0014 - val_sparse_categorical_accuracy: 0.9999\n", "Epoch 5/5\n", "122/122 [==============================] - 51s 415ms/step - loss: 0.0016 - sparse_categorical_accuracy: 0.9997 - val_loss: 0.0036 - val_sparse_categorical_accuracy: 0.9993\n" ] } ] }, { "cell_type": "code", "source": [ "model.save_weights('./distill_bert_fake_news_saved_weights_epoch_5_0-8.h5')" ], "metadata": { "id": "LuTypVNSz8oA" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "model.evaluate(val_ds)" ], "metadata": { "id": "niOM7TYi__N6", 
"outputId": "6fea2853-4c43-44fb-8b6b-52a4eca4ef8c", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "30/30 [==============================] - 5s 121ms/step - loss: 0.0036 - sparse_categorical_accuracy: 0.9993\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[0.0036058793775737286, 0.9993489980697632]" ] }, "metadata": {}, "execution_count": 129 } ] }, { "cell_type": "code", "source": [ "output = model.predict(val_ds)" ], "metadata": { "id": "p8IY5uTcATkX" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# val_ds is not shuffled, so this yields the same batches, in the same order, as model.predict(val_ds)\n", "val_batches = list(val_ds.as_numpy_iterator()) # list of (features_dict, labels) batches" ], "metadata": { "id": "YD1sez_xIdh9" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "val_true = [labels for _, labels in val_batches] # one label array per batch" ], "metadata": { "id": "kbtefx9QIvOz" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "val_text = [features['input_ids'] for features, _ in val_batches] # token ids per batch" ], "metadata": { "id": "-5DN3sUla3zK" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import numpy as np\n", "val_true_ = np.concatenate(val_true)" ], "metadata": { "id": "jNTnsNwnJc0Q" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "val_text_ = np.concatenate(val_text)" ], "metadata": { "id": "6bZo_FDkbbaY" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from scipy.special import softmax" ], "metadata": { "id": "7LuMPSGWKzdF" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "output_softmax = softmax(output.logits,axis=1)" ], "metadata": { "id": "f-YBX5JWKmsU" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "output_values = np.argmax(output_softmax,axis=1)" ], "metadata": { "id": 
"QQCEkX9eK5Kl" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.metrics import roc_curve\n", "fpr, tpr, thresholds = roc_curve(val_true_, output_softmax[:,1])\n", "from sklearn.metrics import roc_auc_score\n", "auc = roc_auc_score(val_true_, output_softmax[:,1])\n", "print('AUC: %.3f' % auc)" ], "metadata": { "id": "yIOR8wfAKgMs", "outputId": "05f98fac-c218-4180-f8a8-54540120ea9f", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "AUC: 1.000\n" ] } ] }, { "cell_type": "code", "source": [ "import matplotlib.pyplot as plt\n", "\n", "#create ROC curve\n", "plt.plot(fpr,tpr)\n", "plt.ylabel('True Positive Rate')\n", "plt.xlabel('False Positive Rate')\n", "plt.show()" ], "metadata": { "id": "dd711OzxLpoi", "outputId": "5ac29de9-6ebf-4705-c00a-779ebe8ad504", "colab": { "base_uri": "https://localhost:8080/", "height": 279 } }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXBklEQVR4nO3de7RedX3n8fdHLoLcnJo4o0AM2niJlwo9gxdGxWIVUaEdkcvIWFqWaVWsHdA1tLjQQWvHUu0qLa1GZaGOCkirTTWaaS2IowKJgkBCcaV4IVyGFBmqxQvod/7Y+9Snh3N5kpy9j+fs92uts86+/J79fHdOks/57f3s3y9VhSRpuB6y0AVIkhaWQSBJA2cQSNLAGQSSNHAGgSQN3O4LXcCOWrZsWa1cuXKhy5CkReUrX/nKP1XV8un2LbogWLlyJZs2bVroMiRpUUnyrZn2eWlIkgbOIJCkgTMIJGngDAJJGjiDQJIGrrMgSHJhkruS3DjD/iQ5P8nWJNcnOayrWiRJM+uyR3ARcPQs+18MrGq/1gB/0WEtkqQZdPYcQVVdmWTlLE2OAz5UzTjYVyV5eJJHVdUdXdTz0au/zV9fd1sXh5akXqx+9P685WVPnvfjLuQ9ggOBW0fWt7XbHiTJmiSbkmzavn37Tr3ZX193G1vu+Oedeq0kLWWL4sniqloLrAWYmJjY6Zl0Vj9qfy75zWfNW12StBQsZI/gNuDgkfWD2m2SpB4tZBCsA17VfnromcC9Xd0fkCTNrLNLQ0k+BhwJLEuyDXgLsAdAVb0HWA8cA2wF7gN+vataJEkz6/JTQyfPsb+A13X1/pKk8fhksSQNnEEgSQNnEEjSwBkEkjRwBoEkDZxBIEkDZxBI0sAZBJI0cAaBJA2cQSBJA2cQSNLAGQSSNHAGgSQNnEEgSQNnEEjSwBkEkjRwBoEkDZxBIEkDZxBI0sAZBJI0cAaBJA2cQSBJA2cQSNLAGQSSNHAGgSQNnEEgSQNnEEjSwBkEkjRwBoEkDZxBIEkDZxBI0sB1GgRJjk5yc5KtSc6aZv+KJJcnuTbJ9UmO6bIeSdKDdRYESXYDLgBeDKwGTk6yekqzNwOXVtWhwEnAn3dVjyRpel32CA4HtlbVLVX1I+Bi4LgpbQrYv10+ALi9w3okSdPoMggOBG4dWd/Wbhv1VuCUJNuA9cDrpztQkjVJNiXZtH379i5qlaTBWuibxScDF1XVQcAxwIeTPKimqlpbVRNVNbF8+fLei5SkpazLILgNOHhk/aB226jTgEsBqurLwF7Asg5rkiRN0WUQbARWJTkkyZ40N4PXTWnzbeAogCRPogkCr/1IUo86C4KqegA4HdgA3ETz6aDNSc5Ncmzb7Ezg1Um+BnwMOLWqqquaJEkPtnuXB6+q9TQ3gUe3nTOyvAU4ossaJEmzW+ibxZKkBWYQSNLAGQSSNHAGgSQNnEEgSQNnEEjSwBkEkjRwBoEkDZxBIEkDZxBI0sCNHQRJHtZlIZKkhTFnECR5dpItwD+067+QxCklJWmJGKdH8MfAi4C7Aarqa8BzuyxKktSfsS4NVdWtUzb9uINaJEkLYJxhqG9N8mygkuwBvIFmfgFJ0hIwTo/gt4DX0Uw8fxvwdOC1XRYlSerPOD2CJ1TVK0c3JDkC+GI3JUmS+jROj+BPx9wmSVqEZuwRJHkW8GxgeZIzRnbtD+zWdWGSpH7MdmloT2Dfts1+I9v/GTi+y6IkSf2ZMQiq6vPA55NcVFXf6rEmSVKPxrlZfF+S84AnA3tNbqyqX+qsKklSb8a5WfwRmuElDgH+B/BNYGOHNUmSejROEDyiqj4A3F9Vn6+q3wDsDUjSEjHOpaH72+93JHkJcDvwc92VJEnq0zhB8PYkBwBn0jw/sD/wO51WJUnqzZxBUFWfahfvBZ4P//pksSRpCZjtgbLdgBNoxhj6bFXdmOSlwO8BewOH9lOiJ
KlLs/UIPgAcDFwDnJ/kdmACOKuqPtlHcZKk7s0WBBPA06rqJ0n2Au4EHldVd/dTmiSpD7N9fPRHVfUTgKr6AXDLjoZAkqOT3Jxka5KzZmhzQpItSTYn+eiOHF+StOtm6xE8Mcn17XKAx7XrAaqqnjbbgdt7DBcAvwxsAzYmWVdVW0barAJ+Fziiqu5J8shdOBdJ0k6YLQietIvHPhzYWlW3ACS5GDgO2DLS5tXABVV1D0BV3bWL7ylJ2kGzDTq3qwPNHQiMznW8DXjGlDaPB0jyRZqhrd9aVZ+deqAka4A1ACtWrNjFsiRJo8aavL5DuwOrgCOBk4H3JXn41EZVtbaqJqpqYvny5T2XKElLW5dBcBvNx08nHdRuG7UNWFdV91fVN4Cv0wSDJKknYwVBkr2TPGEHj70RWJXkkCR7AicB66a0+SRNb4Aky2guFd2yg+8jSdoFcwZBkpcB1wGfbdefnmTqf+gPUlUPAKcDG4CbgEuranOSc5Mc2zbbANydZAtwOfAmn1OQpH6NM+jcW2k+AXQFQFVdl+SQcQ5eVeuB9VO2nTOyXMAZ7ZckaQGMc2no/qq6d8q26qIYSVL/xukRbE7yX4Dd2gfAfhv4UrdlSZL6Mk6P4PU08xX/EPgozXDUzkcgSUvEOD2CJ1bV2cDZXRcjSerfOD2CdyW5Kcnbkjyl84okSb2aMwiq6vk0M5NtB96b5IYkb+68MklSL8Z6oKyq7qyq84Hfonmm4Jw5XiJJWiTGeaDsSUnemuQGmsnrv0QzXIQkaQkY52bxhcAlwIuq6vaO65Ek9WzOIKiqZ/VRiCRpYcwYBEkuraoT2ktCo08SjzVDmSRpcZitR/CG9vtL+yhEkrQwZrxZXFV3tIuvrapvjX4Br+2nPElS18b5+OgvT7PtxfNdiCRpYcx2j+A1NL/5PzbJ9SO79gO+2HVhkqR+zHaP4KPAZ4A/AM4a2f7dqvpOp1VJknozWxBUVX0zyeum7kjyc4aBJC0Nc/UIXgp8hebjoxnZV8BjO6xLktSTGYOgql7afh9rWkpJ0uI0zlhDRyTZp10+Jcm7k6zovjRJUh/G+fjoXwD3JfkF4EzgH4EPd1qVJKk34wTBA1VVwHHAn1XVBTQfIZUkLQHjjD763SS/C/xX4DlJHgLs0W1ZkqS+jNMjOJFm4vrfqKo7aeYiOK/TqiRJvRlnqso7gY8AByR5KfCDqvpQ55VJknoxzqeGTgCuAV4BnABcneT4rguTJPVjnHsEZwP/saruAkiyHPg74LIuC5Mk9WOcewQPmQyB1t1jvk6StAiM0yP4bJINwMfa9ROB9d2VJEnq0zhzFr8pyX8G/lO7aW1VfaLbsiRJfZltPoJVwB8BjwNuAN5YVbf1VZgkqR+zXeu/EPgU8HKaEUj/dEcPnuToJDcn2ZrkrFnavTxJJZnY0feQJO2a2S4N7VdV72uXb07y1R05cJLdgAtoprrcBmxMsq6qtkxptx/wBuDqHTm+JGl+zBYEeyU5lJ/OQ7D36HpVzRUMhwNbq+oWgCQX04xXtGVKu7cB7wTetIO1S5LmwWxBcAfw7pH1O0fWC/ilOY59IHDryPo24BmjDZIcBhxcVZ9OMmMQJFkDrAFYscIRsCVpPs02Mc3zu3zjdvC6dwOnztW2qtYCawEmJiaqy7okaWi6fDDsNuDgkfWD2m2T9gOeAlyR5JvAM4F13jCWpH51GQQbgVVJDkmyJ3ASsG5yZ1XdW1XLqmplVa0ErgKOrapNHdYkSZqisyCoqgeA04ENwE3ApVW1Ocm5SY7t6n0lSTtmzieLkwR4JfDYqjq3na/4P1TVNXO9tqrWM2U4iqo6Z4a2R45VsSRpXo3TI/hz4FnAye36d2meD5AkLQHjDDr3jKo6LMm1AFV1T3vNX5K0BIzTI7i/fUq44F/nI/hJp1VJknozThCcD3wCeGSS3wf+D/COTquSJPVmnGGoP5LkK8BRNMNL/EpV3dR5ZZKkXozzqaEVwH3A3
4xuq6pvd1mYJKkf49ws/jTN/YEAewGHADcDT+6wLklST8a5NPTU0fV2oLjXdlaRJKlXO/xkcTv89DPmbChJWhTGuUdwxsjqQ4DDgNs7q0iS1Ktx7hHsN7L8AM09g7/sphxJUt9mDYL2QbL9quqNPdUjSerZjPcIkuxeVT8GjuixHklSz2brEVxDcz/guiTrgI8D/zK5s6r+quPaJEk9GOcewV7A3TRzFE8+T1CAQSBJS8BsQfDI9hNDN/LTAJjkvMGStETMFgS7AfvybwNgkkEgSUvEbEFwR1Wd21slkqQFMduTxdP1BCRJS8xsQXBUb1VIkhbMjEFQVd/psxBJ0sLY4UHnJElLi0EgSQNnEEjSwBkEkjRwBoEkDZxBIEkDZxBI0sAZBJI0cAaBJA1cp0GQ5OgkNyfZmuSsafafkWRLkuuTfC7JY7qsR5L0YJ0FQTvf8QXAi4HVwMlJVk9pdi0wUVVPAy4D/rCreiRJ0+uyR3A4sLWqbqmqHwEXA8eNNqiqy6vqvnb1KuCgDuuRJE2jyyA4ELh1ZH1bu20mpwGfmW5HkjVJNiXZtH379nksUZL0M3GzOMkpwARw3nT7q2ptVU1U1cTy5cv7LU6SlrhxJq/fWbcBB4+sH9Ru+zeSvAA4G3heVf2ww3okSdPoskewEViV5JAkewInAetGGyQ5FHgvcGxV3dVhLZKkGXQWBFX1AHA6sAG4Cbi0qjYnOTfJsW2z84B9gY8nuS7JuhkOJ0nqSJeXhqiq9cD6KdvOGVl+QZfvL0ma28/EzWJJ0sIxCCRp4AwCSRo4g0CSBs4gkKSBMwgkaeAMAkkaOINAkgbOIJCkgTMIJGngDAJJGjiDQJIGziCQpIEzCCRp4AwCSRo4g0CSBs4gkKSBMwgkaeAMAkkaOINAkgbOIJCkgTMIJGngDAJJGjiDQJIGziCQpIEzCCRp4AwCSRo4g0CSBs4gkKSBMwgkaeAMAkkauE6DIMnRSW5OsjXJWdPsf2iSS9r9VydZ2WU9kqQH6ywIkuwGXAC8GFgNnJxk9ZRmpwH3VNXPA38MvLOreiRJ0+uyR3A4sLWqbqmqHwEXA8dNaXMc8MF2+TLgqCTpsCZJ0hS7d3jsA4FbR9a3Ac+YqU1VPZDkXuARwD+NNkqyBlgDsGLFip0qZvWj99+p10nSUtdlEMybqloLrAWYmJionTnGW1725HmtSZKWii4vDd0GHDyyflC7bdo2SXYHDgDu7rAmSdIUXQbBRmBVkkOS7AmcBKyb0mYd8Gvt8vHA31fVTv3GL0naOZ1dGmqv+Z8ObAB2Ay6sqs1JzgU2VdU64APAh5NsBb5DExaSpB51eo+gqtYD66dsO2dk+QfAK7qsQZI0O58slqSBMwgkaeAMAkkaOINAkgYui+3Tmkm2A9/ayZcvY8pTywPgOQ+D5zwMu3LOj6mq5dPtWHRBsCuSbKqqiYWuo0+e8zB4zsPQ1Tl7aUiSBs4gkKSBG1oQrF3oAhaA5zwMnvMwdHLOg7pHIEl6sKH1CCRJUxgEkjRwSzIIkhyd5OYkW5OcNc3+hya5pN1/dZKV/Vc5v8Y45zOSbElyfZLPJXnMQtQ5n+Y655F2L09SSRb9Rw3HOeckJ7Q/681JPtp3jfNtjL/bK5JcnuTa9u/3MQtR53xJcmGSu5LcOMP+JDm//fO4Pslhu/ymVbWkvmiGvP5H4LHAnsDXgNVT2rwWeE+7fBJwyULX3cM5Px94WLv8miGcc9tuP+BK4CpgYqHr7uHnvAq4Fvh37fojF7ruHs55LfCadnk18M2FrnsXz/m5wGHAjTPsPwb4DBDgmcDVu/qeS7FHcDiwtapuqaofARcDx01pcxzwwXb5MuCoJOmxxvk25zlX1eVVdV+7ehXNjHGL2Tg/Z4C3Ae8EftBncR0Z55xfDVxQVfcAVNVdPdc438Y55wImJyU/ALi9x/rmXVVdSTM/y0yOAz5UjauAhyd51K6851IMggOBW0fWt7Xbpm1TVQ8A9
wKP6KW6boxzzqNOo/mNYjGb85zbLvPBVfXpPgvr0Dg/58cDj0/yxSRXJTm6t+q6Mc45vxU4Jck2mvlPXt9PaQtmR/+9z2lRTF6v+ZPkFGACeN5C19KlJA8B3g2cusCl9G13mstDR9L0+q5M8tSq+n8LWlW3TgYuqqp3JXkWzayHT6mqnyx0YYvFUuwR3AYcPLJ+ULtt2jZJdqfpTt7dS3XdGOecSfIC4Gzg2Kr6YU+1dWWuc94PeApwRZJv0lxLXbfIbxiP83PeBqyrqvur6hvA12mCYbEa55xPAy4FqKovA3vRDM62VI31731HLMUg2AisSnJIkj1pbgavm9JmHfBr7fLxwN9XexdmkZrznJMcCryXJgQW+3VjmOOcq+reqlpWVSuraiXNfZFjq2rTwpQ7L8b5u/1Jmt4ASZbRXCq6pc8i59k45/xt4CiAJE+iCYLtvVbZr3XAq9pPDz0TuLeq7tiVAy65S0NV9UCS04ENNJ84uLCqNic5F9hUVeuAD9B0H7fS3JQ5aeEq3nVjnvN5wL7Ax9v74t+uqmMXrOhdNOY5LyljnvMG4IVJtgA/Bt5UVYu2tzvmOZ8JvC/Jf6O5cXzqYv7FLsnHaMJ8WXvf4y3AHgBV9R6a+yDHAFuB+4Bf3+X3XMR/XpKkebAULw1JknaAQSBJA2cQSNLAGQSSNHAGgSQNnEGgn0lJfpzkupGvlbO0/d48vN9FSb7RvtdX2ydUd/QY70+yul3+vSn7vrSrNbbHmfxzuTHJ3yR5+Bztn77YR+NU9/z4qH4mJfleVe07321nOcZFwKeq6rIkLwT+qKqetgvH2+Wa5jpukg8CX6+q35+l/ak0o66ePt+1aOmwR6BFIcm+7TwKX01yQ5IHjTSa5FFJrhz5jfk57fYXJvly+9qPJ5nrP+grgZ9vX3tGe6wbk/xOu22fJJ9O8rV2+4nt9iuSTCT5n8DebR0fafd9r/1+cZKXjNR8UZLjk+yW5LwkG9sx5n9zjD+WL9MONpbk8PYcr03ypSRPaJ/EPRc4sa3lxLb2C5Nc07adbsRWDc1Cj73tl1/TfdE8FXtd+/UJmqfg92/3LaN5qnKyR/u99vuZwNnt8m404w0to/mPfZ92+38Hzpnm/S4Cjm+XXwFcDfwicAOwD81T2ZuBQ4GXA+8bee0B7fcraOc8mKxppM1kjb8KfLBd3pNmFMm9gTXAm9vtDwU2AYdMU+f3Rs7v48DR7fr+wO7t8guAv2yXTwX+bOT17wBOaZcfTjMW0T4L/fP2a2G/ltwQE1oyvl9VT59cSbIH8I4kzwV+QvOb8L8H7hx5zUbgwrbtJ6vquiTPo5ms5Ivt0Bp70vwmPZ3zkryZZpya02jGr/lEVf1LW8NfAc8BPgu8K8k7aS4nfWEHzuszwJ8keShwNHBlVX2/vRz1tCTHt+0OoBks7htTXr93kuva878J+NuR9h9MsopmmIU9Znj/FwLHJnlju74XsKI9lgbKINBi8UpgOfCLVXV/mhFF9xptUFVXtkHxEuCiJO8G7gH+tqpOHuM93lRVl02uJDlqukZV9fU0cx0cA7w9yeeq6txxTqKqfpDkCuBFwIk0E61AM9vU66tqwxyH+H5VPT3Jw2jG33kdcD7NBDyXV9WvtjfWr5jh9QFeXlU3j1OvhsF7BFosDgDuakPg+cCD5lxOMw/z/62q9wHvp5nu7yrgiCST1/z3SfL4Md/zC8CvJHlYkn1oLut8Icmjgfuq6n/RDOY33Zyx97c9k+lcQjNQ2GTvApr/1F8z+Zokj2/fc1rVzDb328CZ+elQ6pNDEZ860vS7NJfIJm0AXp+2e5RmVFoNnEGgxeIjwESSG4BXAf8wTZsjga8luZbmt+0/qartNP8xfizJ9TSXhZ44zhtW1Vdp7h1cQ3PP4P1VdS3wVOCa9hLNW4C3T/PytcD1kzeLp/jfNBMD/V010y9CE1xbgK+mmbT8vczRY29ruZ5mYpY/BP6gPffR110OrJ68WUzTc
9ijrW1zu66B8+OjkjRw9ggkaeAMAkkaOINAkgbOIJCkgTMIJGngDAJJGjiDQJIG7v8DzNYU2yupHz8AAAAASUVORK5CYII=\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "source": [ "bool_ = output_values == val_true_" ], "metadata": { "id": "PdJ6rBPRN14W" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "bool_ = list(bool_)" ], "metadata": { "id": "Lv68D_g5N9G7" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "indices = [i for i, x in enumerate(bool_) if x == False]" ], "metadata": { "id": "-0N0Pg2nOam5" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "indices = [1339, 2615, 6600, 7026] #wrong prediction" ], "metadata": { "id": "QxxBij34OJC2" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "for i in indices: # 0 -> True News 1-> False News\n", " print('Sample ',i)\n", " print('Predicted Value: ',np.argmax(softmax(model.predict(val_text_[i,:].reshape(1,512)).logits,axis=1),axis=1),' True Value: ',val_true_[i])\n" ], "metadata": { "id": "_b9RiItGcKI1", "outputId": "0badb610-f732-445a-b53b-8bec9d297463", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Sample 1339\n", "Predicted Value: [0] True Value: 1\n", "Sample 2615\n", "Predicted Value: [0] True Value: 1\n", "Sample 6600\n", "Predicted Value: [0] True Value: 1\n", "Sample 7026\n", "Predicted Value: [0] True Value: 1\n" ] } ] }, { "cell_type": "code", "source": [ "print('Wrong Prediction')\n", "for i in indices:\n", " print(tokenizer.decode(val_text_[i]))" ], "metadata": { "id": "ruoYIFXkipMX", "outputId": "fe11069a-c0d3-42e9-bbcb-9d21bc538399", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Wrong Prediction\n", "[CLS] trump tells state department to make cut more than 50 % of funding 
to u. n. president trump s administration has told the state department to cut more than 50 percent of u. s. funding to united nations programs, foreign policy reported. the push for the drastic reductions comes as the white house is scheduled to release its 2018 topline budget proposal thursday, which is expected to include a 37 percent cut to the state department and u. s. agency for international development budgets. it s not clear if trump s budget plan, from the office of management and budget, would reflect the full extent of trump s proposed cuts to the u. n. richard gowan, a u. n. expert at the european council on foreign relations, said the alterations would spark chaos if true. [ it would ] leave a gaping hole that other big donors would struggle to fill, he told fp, pointing to how the u. s. provided $ 1. 5 billion of the u. n. refugee agency s $ 4 billion budget last year. via : the hill [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\n", "[CLS] u. s. news and world report publishes list of top 10 most popular nations where refugees want to live more than 21, 000 people from all regions of the world participated in the best countries survey, in which they assessed how closely they associated 80 countries with specific characteristics. four of these economically stable, good job market, income equality and is a place i would live were included in the best countries to be an immigrant ranking. countries also were scored in relation to others on the share of migrants in their population ; the amount of remittances the migrants they host sent home ; and graded on a united nations assessment of integration measures provided for immigrants, such as language training and transfers of job certifications, and the rationale behind current integration policies. [ lol! not assimilation! ed ] note how important remittances are! those are dollars sent out of the host economy and i will bet a buck that any economic study that seeks to justify migration benefits to a country, never factors in how much of the migrants earnings ( or their welfare payments! ) are sent out of the country ( and thus lost to the host country s economy! ). ann corcoran refugee resettlement watch1. 
sweden ( this, coincidentally, is the country i have for a long time ranked as # 1 to fall to the islamists! ) 2. canada ( just be sure our northern border is fortified! ) 3. switzerland4. australia5. germany6. norway7. us8. netherlands9. finland10. denmarknotice that arab ( mostly muslim ) countries are not a desired destination. gee, why is that? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\n", "[CLS] energy department to close office of international climate and technology in response to the u. s. withdrawing from the paris climate agreement earlier this month, the energy department is shutting down the office of international climate and technology, a department that works with other countries to develop clean energy technology. 
an agency spokesman tried to justify the energy department shutting down the office by stating that the doe is looking for ways to consolidate the many duplicative programs that currently exist within doe, thus the office of international climate and technology is getting the chop. the 11 - person office has been in operation since 2010, operating as a means for the u. s. to work with international partners on energy sector technology in an effort to reduce greenhouse gases. the employees of the office of international climate and technology also play a large part in the clean energy ministerial, a conference for high - polluting nations to focus on making the energy sector greener. doe spokeswoman shaylyn hynes said there are numerous international offices within the energy department that could take on the work of the office of international climate and technology, however, she failed to acknowledge whether one actually would. the office of energy efficiency and renewable energy ( eere ) has an international affairs team, while the international affairs office has a renewables team, hynes said. the department is looking for ways to eliminate this kind of unnecessary duplication just like any responsible american business would. the closing of this particular office is most likely a direct result of trump s 2018 budget proposal, which slashes funding for both the doe and the environmental protection agency, particularly cuts for climate change initiatives and research efforts. naturally, environmentalists are horrified by the news of the international office s closure. willfully ignoring the climate crisis is recklessly and unnecessarily dangerous for families and communities across the country, and it s clear that trump will stop at nothing to completely isolate the united states and irreparably damage our reputation with the rest of the world, said john coequyt, the global climate policy director at the sierra club. 
ignorance is not diplomacy, and if trump were acting like a leader, he would know that. hynes responded by saying that the trump administration is not bringing an end to its clean energy efforts, drawing particular attention to energy secretary rick perry s support for carbon capture storage and nuclear energy efforts at a recent clean energy ministerial. featured image via kevin frayer / getty images [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\n", "[CLS] democrats go on the attack against voter id laws virginia s voter id law is being challenged in court and with the loss of a conservative majority on the supreme court, this might just be the first domino to fall, signaling the end of wide - spread voter suppression by republicans : ( reuters ) a virginia law requiring voters to show photo identification goes on trial in federal court on monday, with democratic officials claiming it is discriminatory and aimed at keeping party voters from casting ballots. defenders of the 2013 virginia law say that it is aimed at preventing voter fraud. the trial, in u. s. district court in richmond, virginia, is one of several voting rights legal battles in process as democrats and republicans square off ahead of november s presidential election. the democratic party of virginia and two party activists are suing the virginia state board of elections and want judge henry hudson to strike down the law. the problem for the right is that the premise behind these laws has always been a lie. republicans only became concerned with voter fraud when it became clear that their white hegemony was in immediate danger. adding to this problem was the over - the - top way republicans attacked people s ability to vote. if voter id had been a stand alone law, then republicans might have been able to maintain the fiction that it was just about voter fraud. 
however, they also cut voting hours, the number of polling places ( but only in democratic - leaning areas ), made the ids difficult to get and, of course, kept bragging about how it would hurt the democrats. and just to be absolutely clear that voter id laws are bullshit, here s my favorite quote from the new york times ( 4 / 12 / 2007 ) : five years after the bush administration began a crackdown on voter fraud, the justice department has turned up virtually no evidence of any organized effort to skew federal elections, according to court records and interviews. although republican activists have repeatedly said fraud is so widespread that it has corrupted the political process and, possibly, cost the party election victories, about 120 people have been charged and 86 convicted as of last year. widespread voter fraud does not exist. the only way for a court to uphold voter id is to willfully turn a blind eye to the glaring pattern of voter suppression republicans have openly laid out. if you re wondering why they were so brazen, it s because they were quite confidant that the supreme court would back their partisan attack on democracy. after all, they gutted the voting rights act by claiming, in all seriousness, [SEP]\n" ] } ] }, { "cell_type": "markdown", "source": [ "Only 4 samples are misclassified as true news when they are actually fake news. The model performs well on the test set, although the test set is relatively small (20% of the data). Tuning the model's hyperparameters could give more accurate results on a larger dataset, particularly with an automated hyperparameter search using a library such as Optuna." ], "metadata": { "id": "QzeGqPGWj2sz" } }, { "cell_type": "markdown", "source": [ "**Write up**: \n", "* Link to the model on Hugging Face Hub: \n", "* Include some examples of misclassified news articles. Please explain what you might do to improve your model's performance on these news articles in the future (you do not need to implement these suggestions)" ], "metadata": { "id": "kpInVUMLyJ24" } },
{ "cell_type": "markdown", "metadata": { "id": "jTfHpo6BOmE8" }, "source": [ "# 3. Deep RL / Robotics" ] }, { "cell_type": "markdown", "metadata": { "id": "saB64bbTXWgZ" }, "source": [ "**RL for Classical Control:** Using any of the [classical control](https://github.com/openai/gym/blob/master/docs/environments.md#classic-control) environments from OpenAI's `gym`, implement a deep NN that learns an optimal policy which maximizes the reward of the environment.\n", "\n", "* Describe the NN you implemented and the behavior you observe from the agent as the model converges (or diverges).\n", "* Plot the reward as a function of steps (or epochs) and compare your results to a random agent.\n", "* Discuss whether you think your model has learned the optimal policy and potential methods for improving it and/or where it might fail.\n", "* (Optional) [Upload the model to the Hugging Face Hub](https://huggingface.co/docs/hub/adding-a-model), and add a link to your model below.\n", "\n", "\n", "You may use any frameworks you like, but you must implement your NN on your own (no pre-defined/trained models like [`stable_baselines`](https://stable-baselines.readthedocs.io/en/master/)).\n", "\n", "You may use a simulator other than `gym`, _however_:\n", "* The environment has to be similar to the classical control environments (or more complex, like [`robosuite`](https://github.com/ARISE-Initiative/robosuite)).\n", "* You cannot choose a game/Atari/text-based environment. The purpose of this challenge is to demonstrate an understanding of basic kinematic/dynamic systems."
] }, { "cell_type": "code", "source": [ "### WRITE YOUR CODE TO TRAIN THE MODEL HERE" ], "metadata": { "id": "CUhkTcoeynVv" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "**Write up**: \n", "* (Optional) link to the model on Hugging Face Hub: \n", "* Discuss whether you think your model has learned the optimal policy and potential methods for improving it and/or where it might fail." ], "metadata": { "id": "bWllPZhJyotg" } }, { "cell_type": "markdown", "metadata": { "id": "rbrRbrISa5J_" }, "source": [ "# 4. Theory / Linear Algebra " ] }, { "cell_type": "markdown", "metadata": { "id": "KFkLRCzTXTzL" }, "source": [ "**Implement Contrastive PCA:** Read [this paper](https://www.nature.com/articles/s41467-018-04608-8) and implement contrastive PCA in Python.\n", "\n", "* First, please discuss the kinds of datasets on which it would make sense to use this method\n", "* Implement the method in Python (do not use previous implementations of the method if they already exist)\n", "* Then create a synthetic dataset and apply the method to the synthetic data. Compare with standard PCA.\n" ] }, { "cell_type": "markdown", "source": [ "**Write up**: Discuss the kinds of datasets on which it would make sense to use Contrastive PCA" ], "metadata": { "id": "TpyqWl-ly0wy" } }, { "cell_type": "code", "source": [ "### WRITE YOUR CODE HERE" ], "metadata": { "id": "1CQzUSfQywRk" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# 5.
Systems" ], "metadata": { "id": "dlqmZS5Hy6q-" } }, { "cell_type": "markdown", "source": [ "**Inference on the edge**: Measure the inference times in various computationally constrained settings\n", "\n", "* Pick a few different speech recognition models (we suggest looking at models on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=downloads))\n", "* Simulate different memory constraints and CPU allocations that are realistic for edge devices that might run such models, such as smart speakers or microcontrollers, and measure the average inference time of the models under these conditions\n", "* How does the inference time vary with (1) choice of model, (2) available system memory, (3) available CPU, and (4) size of input?\n", "\n", "Are there any surprising discoveries? (Note that this coding challenge is fairly open-ended, so we will be considering the amount of effort invested in discovering something interesting here)." ], "metadata": { "id": "QW_eiDFw1QKm" } }, { "cell_type": "code", "source": [ "### WRITE YOUR CODE HERE" ], "metadata": { "id": "OYp94wLP1kWJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "**Write up**: What surprising discoveries do you see?" ], "metadata": { "id": "yoHmutWx2jer" } } ] }