diff --git "a/eleventh_doctor_beta2.ipynb" "b/eleventh_doctor_beta2.ipynb"
new file mode 100644--- /dev/null
+++ "b/eleventh_doctor_beta2.ipynb"
@@ -0,0 +1,4176 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "eleventh-doctor-beta.ipynb",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "a49e5fd0d85444a3aa9f786455ca8770": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HBoxView",
+ "_dom_classes": [],
+ "_model_name": "HBoxModel",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "box_style": "",
+ "layout": "IPY_MODEL_73e8d052a86647919649a367aa773c8e",
+ "_model_module": "@jupyter-widgets/controls",
+ "children": [
+ "IPY_MODEL_40760124752846209e61177280a005bd",
+ "IPY_MODEL_eae3f41495884830818311e51920c956",
+ "IPY_MODEL_5d1a116b987549d780ee25723f83d45a"
+ ]
+ }
+ },
+ "73e8d052a86647919649a367aa773c8e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "40760124752846209e61177280a005bd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_46f7e33281354ef488945f5f1cfe4c06",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": "Epoch: 100%",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_e987dfed8c624717b5ae2054cce74f05"
+ }
+ },
+ "eae3f41495884830818311e51920c956": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "ProgressView",
+ "style": "IPY_MODEL_c49e73bdfe544de0ae62034cef7eb0da",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "FloatProgressModel",
+ "bar_style": "success",
+ "max": 4,
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": 4,
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "orientation": "horizontal",
+ "min": 0,
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_52e599d90ccd44938d982310fb7e4341"
+ }
+ },
+ "5d1a116b987549d780ee25723f83d45a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_c01768976adf465ebfad5c3eedfe1d58",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": " 4/4 [00:11<00:00, 2.87s/it]",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_27fb7c7e261b4a3b9656a37b1fcde71a"
+ }
+ },
+ "46f7e33281354ef488945f5f1cfe4c06": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "e987dfed8c624717b5ae2054cce74f05": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "c49e73bdfe544de0ae62034cef7eb0da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "ProgressStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "bar_color": null,
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "52e599d90ccd44938d982310fb7e4341": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "c01768976adf465ebfad5c3eedfe1d58": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "27fb7c7e261b4a3b9656a37b1fcde71a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "1c82670ef31346eb97dff63429fd522f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HBoxView",
+ "_dom_classes": [],
+ "_model_name": "HBoxModel",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "box_style": "",
+ "layout": "IPY_MODEL_a8c2fda5e0be4c638919b4ca1007dea3",
+ "_model_module": "@jupyter-widgets/controls",
+ "children": [
+ "IPY_MODEL_5e60bbde81ed452fa0c8d7094d98b052",
+ "IPY_MODEL_3ac055433ca94c2ebe9f8b44e38be5e0",
+ "IPY_MODEL_67e70ffbe152488fb036968be105a368"
+ ]
+ }
+ },
+ "a8c2fda5e0be4c638919b4ca1007dea3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "5e60bbde81ed452fa0c8d7094d98b052": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_bbeb9a01f5bb4aebba239db555f4b16b",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": "Iteration: 100%",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_771602cc4d9444e7ab0d20438639cddd"
+ }
+ },
+ "3ac055433ca94c2ebe9f8b44e38be5e0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "ProgressView",
+ "style": "IPY_MODEL_f154ee2be8a044b3aeeb0e904411ffbd",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "FloatProgressModel",
+ "bar_style": "success",
+ "max": 5,
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": 5,
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "orientation": "horizontal",
+ "min": 0,
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_a3a6841089054f1cbc31f638424674b3"
+ }
+ },
+ "67e70ffbe152488fb036968be105a368": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_41061818a9c94956a7d1cd129028d805",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": " 5/5 [00:02<00:00, 1.89it/s]",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_bab64ef3864248018e9476bc8c4018f4"
+ }
+ },
+ "bbeb9a01f5bb4aebba239db555f4b16b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "771602cc4d9444e7ab0d20438639cddd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "f154ee2be8a044b3aeeb0e904411ffbd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "ProgressStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "bar_color": null,
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "a3a6841089054f1cbc31f638424674b3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "41061818a9c94956a7d1cd129028d805": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "bab64ef3864248018e9476bc8c4018f4": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "f3fa20cd1c40453bb17b2f109607e1bf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HBoxView",
+ "_dom_classes": [],
+ "_model_name": "HBoxModel",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "box_style": "",
+ "layout": "IPY_MODEL_97b3a5270a014515bbc712b44dba38a0",
+ "_model_module": "@jupyter-widgets/controls",
+ "children": [
+ "IPY_MODEL_afe3a3438fb145d4a015fdb0709e3156",
+ "IPY_MODEL_6d6078316fe54c9e83a3c3a35a1169fc",
+ "IPY_MODEL_4700a281e7d347db8a58c6f181706b54"
+ ]
+ }
+ },
+ "97b3a5270a014515bbc712b44dba38a0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "afe3a3438fb145d4a015fdb0709e3156": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_84f8bdfeb6bf4bb7ba4585eba47a7092",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": "Iteration: 100%",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_9262f880b64a4abb80013f6997901bcb"
+ }
+ },
+ "6d6078316fe54c9e83a3c3a35a1169fc": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "ProgressView",
+ "style": "IPY_MODEL_386289f0bf56453484a6637d3263da4c",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "FloatProgressModel",
+ "bar_style": "success",
+ "max": 5,
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": 5,
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "orientation": "horizontal",
+ "min": 0,
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_8366904443f649928aa9cfd915cd938a"
+ }
+ },
+ "4700a281e7d347db8a58c6f181706b54": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_80f977590ae94733a8a8552241c12e3b",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": " 5/5 [00:02<00:00, 1.79it/s]",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_3ff5971c055144a3b81190d199ffe3de"
+ }
+ },
+ "84f8bdfeb6bf4bb7ba4585eba47a7092": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "9262f880b64a4abb80013f6997901bcb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "386289f0bf56453484a6637d3263da4c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "ProgressStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "bar_color": null,
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "8366904443f649928aa9cfd915cd938a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "80f977590ae94733a8a8552241c12e3b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "3ff5971c055144a3b81190d199ffe3de": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "be5e0fa21fea43e8bf003ae954c29d03": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HBoxView",
+ "_dom_classes": [],
+ "_model_name": "HBoxModel",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "box_style": "",
+ "layout": "IPY_MODEL_7535306bf05847629946e333021e0ef5",
+ "_model_module": "@jupyter-widgets/controls",
+ "children": [
+ "IPY_MODEL_671f7a7556b1412cbd48237293431c0d",
+ "IPY_MODEL_8957849d6dbf44cfafa965d71de78255",
+ "IPY_MODEL_a55488797f71453e917032008f198b9c"
+ ]
+ }
+ },
+ "7535306bf05847629946e333021e0ef5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "671f7a7556b1412cbd48237293431c0d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_38ccc42c4ea14f0e83fca1cb9452bfad",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": "Iteration: 100%",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_ef2ddbabd7b042c0821cf999a9265867"
+ }
+ },
+ "8957849d6dbf44cfafa965d71de78255": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "ProgressView",
+ "style": "IPY_MODEL_0222497184b2446e90f801d84af22b82",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "FloatProgressModel",
+ "bar_style": "success",
+ "max": 5,
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": 5,
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "orientation": "horizontal",
+ "min": 0,
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_6a70d59d1c834989a54deda9e776bf41"
+ }
+ },
+ "a55488797f71453e917032008f198b9c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_62e05dc21f5e438cb5e94e400071a39b",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": " 5/5 [00:02<00:00, 1.86it/s]",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_803884fef29b433db14e61df7fae1ee7"
+ }
+ },
+ "38ccc42c4ea14f0e83fca1cb9452bfad": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "ef2ddbabd7b042c0821cf999a9265867": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "0222497184b2446e90f801d84af22b82": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "ProgressStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "bar_color": null,
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "6a70d59d1c834989a54deda9e776bf41": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "62e05dc21f5e438cb5e94e400071a39b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "803884fef29b433db14e61df7fae1ee7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "83414a06fd504f71aa212d9fce15ffb5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HBoxView",
+ "_dom_classes": [],
+ "_model_name": "HBoxModel",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "box_style": "",
+ "layout": "IPY_MODEL_e2c8785b7c51448296c8cf54331f4a68",
+ "_model_module": "@jupyter-widgets/controls",
+ "children": [
+ "IPY_MODEL_3f9881738b644024aea6371982320d97",
+ "IPY_MODEL_143666ff0c7b4e6491779b64f6212818",
+ "IPY_MODEL_bf699df8e3d844f68a68491c00e8f0bc"
+ ]
+ }
+ },
+ "e2c8785b7c51448296c8cf54331f4a68": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "3f9881738b644024aea6371982320d97": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_9e955ce77097447a8c085ea592ae8a5e",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": "Iteration: 100%",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_2f2aecb73861473ba553b4ebebd52e0b"
+ }
+ },
+ "143666ff0c7b4e6491779b64f6212818": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "ProgressView",
+ "style": "IPY_MODEL_db21f601fd6342a5815cea18c417aa99",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "FloatProgressModel",
+ "bar_style": "success",
+ "max": 5,
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": 5,
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "orientation": "horizontal",
+ "min": 0,
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_67bf9d4f1a8e431bb580337be0e67f82"
+ }
+ },
+ "bf699df8e3d844f68a68491c00e8f0bc": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_ddcba8098de6433a8584045d52cb1f3b",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": " 5/5 [00:02<00:00, 1.89it/s]",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_60062f7860944d0d85c4ef1773c151c3"
+ }
+ },
+ "9e955ce77097447a8c085ea592ae8a5e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "2f2aecb73861473ba553b4ebebd52e0b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "db21f601fd6342a5815cea18c417aa99": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "ProgressStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "bar_color": null,
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "67bf9d4f1a8e431bb580337be0e67f82": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "ddcba8098de6433a8584045d52cb1f3b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "60062f7860944d0d85c4ef1773c151c3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "cc13e655b33d4fa390960d1fa40a0e1f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HBoxView",
+ "_dom_classes": [],
+ "_model_name": "HBoxModel",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "box_style": "",
+ "layout": "IPY_MODEL_2e418d21ae4f4123a9d7b13cbc368605",
+ "_model_module": "@jupyter-widgets/controls",
+ "children": [
+ "IPY_MODEL_ba1f476b8bdc4fce8a703a0220bd4770",
+ "IPY_MODEL_26560e743afa4a3fad0eb1e0ed567a64",
+ "IPY_MODEL_27e7a38811a94548b8dc1980e1c83acd"
+ ]
+ }
+ },
+ "2e418d21ae4f4123a9d7b13cbc368605": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "ba1f476b8bdc4fce8a703a0220bd4770": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_c5ceabb016a74435be6659f1116c9945",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": "Evaluating: ",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_4ca6ee080dad41aa8abbbea5a96e3922"
+ }
+ },
+ "26560e743afa4a3fad0eb1e0ed567a64": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "ProgressView",
+ "style": "IPY_MODEL_c0fd4025a3e84d0ca8c360887d7126ba",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "FloatProgressModel",
+ "bar_style": "success",
+ "max": 1,
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": 0,
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "orientation": "horizontal",
+ "min": 0,
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_4a374856a56c4dc6a163ee5779d6b666"
+ }
+ },
+ "27e7a38811a94548b8dc1980e1c83acd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "HTMLView",
+ "style": "IPY_MODEL_ef067c6b95ac48c58428f62ecef22e33",
+ "_dom_classes": [],
+ "description": "",
+ "_model_name": "HTMLModel",
+ "placeholder": "",
+ "_view_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "value": " 0/0 [00:00<?, ?it/s]",
+ "_view_count": null,
+ "_view_module_version": "1.5.0",
+ "description_tooltip": null,
+ "_model_module": "@jupyter-widgets/controls",
+ "layout": "IPY_MODEL_88521418122646b0b6c7d41be73e747a"
+ }
+ },
+ "c5ceabb016a74435be6659f1116c9945": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "4ca6ee080dad41aa8abbbea5a96e3922": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "c0fd4025a3e84d0ca8c360887d7126ba": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "ProgressStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "bar_color": null,
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "4a374856a56c4dc6a163ee5779d6b666": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": "20px",
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ },
+ "ef067c6b95ac48c58428f62ecef22e33": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_view_name": "StyleView",
+ "_model_name": "DescriptionStyleModel",
+ "description_width": "",
+ "_view_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.5.0",
+ "_view_count": null,
+ "_view_module_version": "1.2.0",
+ "_model_module": "@jupyter-widgets/controls"
+ }
+ },
+ "88521418122646b0b6c7d41be73e747a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_view_name": "LayoutView",
+ "grid_template_rows": null,
+ "right": null,
+ "justify_content": null,
+ "_view_module": "@jupyter-widgets/base",
+ "overflow": null,
+ "_model_module_version": "1.2.0",
+ "_view_count": null,
+ "flex_flow": null,
+ "width": null,
+ "min_width": null,
+ "border": null,
+ "align_items": null,
+ "bottom": null,
+ "_model_module": "@jupyter-widgets/base",
+ "top": null,
+ "grid_column": null,
+ "overflow_y": null,
+ "overflow_x": null,
+ "grid_auto_flow": null,
+ "grid_area": null,
+ "grid_template_columns": null,
+ "flex": null,
+ "_model_name": "LayoutModel",
+ "justify_items": null,
+ "grid_row": null,
+ "max_height": null,
+ "align_content": null,
+ "visibility": null,
+ "align_self": null,
+ "height": null,
+ "min_height": null,
+ "padding": null,
+ "grid_auto_rows": null,
+ "grid_gap": null,
+ "max_width": null,
+ "order": null,
+ "_view_module_version": "1.2.0",
+ "grid_template_areas": null,
+ "object_position": null,
+ "object_fit": null,
+ "grid_auto_columns": null,
+ "margin": null,
+ "display": null,
+ "left": null
+ }
+ }
+ }
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tSIO1yDEJbxI",
+ "outputId": "43bc1501-529c-48bc-d825-08c242d5de04"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
+ ]
+ }
+ ],
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount(\"/content/drive\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip -q install transformers"
+ ],
+ "metadata": {
+ "id": "LwrtmgMvMSey"
+ },
+ "execution_count": 58,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "os.chdir(\"/content/drive/My Drive/Colab Notebooks\")"
+ ],
+ "metadata": {
+ "id": "Mp864lxgIbJE"
+ },
+ "execution_count": 59,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# libraries\n",
+ "\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from __future__ import division\n",
+ "\n",
+ "import random\n",
+ "import glob\n",
+ "import logging\n",
+ "import os\n",
+ "import pickle\n",
+ "import re\n",
+ "import shutil\n",
+ "from typing import List, Dict, Tuple\n",
+ "\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "from torch.nn.utils.rnn import pad_sequence\n",
+ "from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n",
+ "from torch.utils.data.distributed import DistributedSampler\n",
+ "from tqdm.notebook import tqdm, trange\n",
+ "\n",
+ "from pathlib import Path\n",
+ "\n",
+ "from transformers import (\n",
+ " MODEL_WITH_LM_HEAD_MAPPING,\n",
+ " WEIGHTS_NAME,\n",
+ " AdamW,\n",
+ " AutoConfig,\n",
+ " PreTrainedModel,\n",
+ " PreTrainedTokenizer,\n",
+ " get_linear_schedule_with_warmup,\n",
+ ")\n",
+ "\n",
+ "try:\n",
+ " from torch.utils.tensorboard import SummaryWriter\n",
+ "except ImportError:\n",
+ " from tensorboardX import SummaryWriter"
+ ],
+ "metadata": {
+ "id": "ujmUewQ5NVoO"
+ },
+ "execution_count": 60,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# visualize raw data\n",
+ "d = pd.read_csv(\"/content/drive/MyDrive/final-all-scripts.csv\", sep=\"delimiter\", header=None)\n",
+ "d.head()"
+ ],
+ "metadata": {
+ "id": "8gMOER_tVuIr",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 261
+ },
+ "outputId": "f8275436-770a-424d-fb41-a296ddc45045"
+ },
+ "execution_count": 61,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.\n",
+ " \n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " [Class room] | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " (A city is flying through space, stuck on the ... | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " COMPUTER: Well done, Mabel. Well done, Alfie. ... | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " (It is the little boy's turn.) | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " COMPUTER: Bad boy, Timmy. | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " 0\n",
+ "0 [Class room]\n",
+ "1 (A city is flying through space, stuck on the ...\n",
+ "2 COMPUTER: Well done, Mabel. Well done, Alfie. ...\n",
+ "3 (It is the little boy's turn.)\n",
+ "4 COMPUTER: Bad boy, Timmy."
+ ]
+ },
+ "metadata": {},
+ "execution_count": 61
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Data Preprocessing"
+ ],
+ "metadata": {
+ "id": "Vr2Y_QbooJUM"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(f\"Data type of file: {type(d)}\",\n",
+ " f\"\\nShape of file: {d.shape}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_mKAYMdxj2-l",
+ "outputId": "5bb10015-4046-41c8-e774-ed42e87ccfc7"
+ },
+ "execution_count": 62,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Data type of file: \n",
+ "Shape of file: (27597, 1)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(f\"Type of first element: {type(d.iloc[0])}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YiVX_L8P0P1K",
+ "outputId": "e280fb35-b11d-42d8-8ba3-50843120cea9"
+ },
+ "execution_count": 63,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Type of first element: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dd = []\n",
+ "\n",
+ "for i in d[0]:\n",
+ " if not (i.startswith(\"(\") or i.startswith(\"[\")):\n",
+ " dd.append(i)\n",
+ "\n",
+ "dd[:10]"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "y019y7pT02GL",
+ "outputId": "30c9a13b-3a42-4ac2-871b-9f830d6aa5a4"
+ },
+ "execution_count": 64,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "['COMPUTER: Well done, Mabel. Well done, Alfie. Good girl, Tabitha. Very well done, Ranjit. Good girl, Chloe. Well done, Ben. Well done, Mandy.',\n",
+ " 'COMPUTER: Bad boy, Timmy.',\n",
+ " 'COMPUTER: Zero.',\n",
+ " \"MANDY: You got a zero, didn't you?\",\n",
+ " 'TIMMY: Yeah? So?',\n",
+ " \"MANDY: You'll have to walk home then.\",\n",
+ " \"TIMMY: Walk to London? That's twenty decks!\",\n",
+ " \"MANDY: You can't ride a Vator with a zero. You know what happens. You'll get sent below.\",\n",
+ " \"MANDY: I'll wait for you.\",\n",
+ " \"SMILER: Welcome to Vator Verse, sponsored by McLintock's Candy Burgers. TIMMY: London, please.\"]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 64
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# person-text split\n",
+ "#dd[1].split(\":\")\n",
+ "\n",
+ "# each dialogue\n",
+ "#dialogues[0][1]"
+ ],
+ "metadata": {
+ "id": "fsT-q0762yJ8"
+ },
+ "execution_count": 65,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dialogues = [l.split(\":\") for l in dd]\n",
+ "len(dialogues)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "lcA-TTbR64_q",
+ "outputId": "9b59a4f0-b3fb-4ff5-c2a3-059183309548"
+ },
+ "execution_count": 66,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "20594"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 66
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chars= []\n",
+ "txts = []\n",
+ "\n",
+ "for i in range(len(dialogues)):\n",
+ " chars.append(dialogues[i][0])\n",
+ " txts.append(dialogues[i][1])"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 200
+ },
+ "id": "FtSYUoGO7XLk",
+ "outputId": "78760c97-5c54-4b9b-b263-a9ee0809ca1f"
+ },
+ "execution_count": 67,
+ "outputs": [
+ {
+ "output_type": "error",
+ "ename": "IndexError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdialogues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mchars\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdialogues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtxts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdialogues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mIndexError\u001b[0m: list index out of range"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "len(chars)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "TWxj098VzBXF",
+ "outputId": "198df600-0cbf-431b-eede-58a40b62f108"
+ },
+ "execution_count": 68,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "104"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 68
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dialogues[len(dialogues)-1][1]"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "id": "l0PCKQI5LkuR",
+ "outputId": "fb6f3000-7378-4f30-bd3a-e193bce34ba8"
+ },
+ "execution_count": 69,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "' So, dresses, then.'"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 69
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#dialogues[len(dialogues)-1][1] == dialogues[-1][1]"
+ ],
+ "metadata": {
+ "id": "kFVPluE3-ojX"
+ },
+ "execution_count": 70,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#pd.isnull(dialogues).sum()"
+ ],
+ "metadata": {
+ "id": "oemthiCWInCq"
+ },
+ "execution_count": 71,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# seperate person-text and convert dataframe\n",
+ "df = pd.DataFrame(list(zip(chars, txts)), columns=[\"Character\", \"Text\"])\n",
+ "df.head()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 206
+ },
+ "id": "UqgM35hBrQp6",
+ "outputId": "346eab19-132c-433b-9a99-3e1012b83eac"
+ },
+ "execution_count": 72,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Character | \n",
+ " Text | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " COMPUTER | \n",
+ " Well done, Mabel. Well done, Alfie. Good girl... | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " COMPUTER | \n",
+ " Bad boy, Timmy. | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " COMPUTER | \n",
+ " Zero. | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " MANDY | \n",
+ " You got a zero, didn't you? | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " TIMMY | \n",
+ " Yeah? So? | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " Character Text\n",
+ "0 COMPUTER Well done, Mabel. Well done, Alfie. Good girl...\n",
+ "1 COMPUTER Bad boy, Timmy.\n",
+ "2 COMPUTER Zero.\n",
+ "3 MANDY You got a zero, didn't you?\n",
+ "4 TIMMY Yeah? So?"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 72
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "CHARACTER_NAME = \"DOCTOR\""
+ ],
+ "metadata": {
+ "id": "FY639-Fi7WfF"
+ },
+ "execution_count": 73,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "contexted = []\n",
+ "\n",
+ "# context window of size 7\n",
+ "n = 7\n",
+ "\n",
+ "for i in df[df.Character == CHARACTER_NAME].index:\n",
+ " if i < n:\n",
+ " continue\n",
+ " row = []\n",
+ " prev = i - 1 - n # we additionally substract 1, so row will contain current responce and 7 previous responces \n",
+ " for j in range(i, prev, -1):\n",
+ " row.append(df.Text[j])\n",
+ " contexted.append(row)\n",
+ "\n",
+ "columns = ['response', 'context'] \n",
+ "columns = columns + ['context/' + str(i) for i in range(n - 1)]\n",
+ "\n",
+ "df = pd.DataFrame.from_records(contexted, columns=columns)"
+ ],
+ "metadata": {
+ "id": "vSqHrtAOz_1j"
+ },
+ "execution_count": 74,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "df.sample(6)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 490
+ },
+ "id": "GEVvAnKN0xad",
+ "outputId": "cdd5a2d2-f465-4d2a-947a-7269447ede2e"
+ },
+ "execution_count": 75,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " response | \n",
+ " context | \n",
+ " context/0 | \n",
+ " context/1 | \n",
+ " context/2 | \n",
+ " context/3 | \n",
+ " context/4 | \n",
+ " context/5 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 8 | \n",
+ " An important thing. In fact, Thing One. We ar... | \n",
+ " A thing? | \n",
+ " Course we can. But first, there's a thing. | \n",
+ " Can we go out and see? | \n",
+ " Well, come on. I've found us a spaceship. Thi... | \n",
+ " Doctor! | \n",
+ " Isn't that amazing? | \n",
+ " Doctor? | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " Don't know. I think a lot. It's hard to keep ... | \n",
+ " Why did you just do that with the water? | \n",
+ " Sorry. Checking all the water in this area. T... | \n",
+ " What are you doing? | \n",
+ " Life on a giant starship. Back to basics. Bic... | \n",
+ " London Market is a crime-free zone. | \n",
+ " Now, come on, look around you. Actually look. | \n",
+ " Oh my God, I'm in my nightie. | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " Come on, Pond. | \n",
+ " My name is Amy Pond. When I was seven, I had ... | \n",
+ " Help! Help me! | \n",
+ " Though the man above might say hello, expect ... | \n",
+ " A horse and a man, above, below. One has a pl... | \n",
+ " Welcome to Vator Verse, sponsored by McLintoc... | \n",
+ " I'll wait for you. | \n",
+ " You can't ride a Vator with a zero. You know ... | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " What I always do. Stay out of trouble. Badly. | \n",
+ " What are you going to do? | \n",
+ " It's this or Leadworth. What do you think? Le... | \n",
+ " No, hang on. What do I do? I don't know what ... | \n",
+ " They're clean. Everything else here is all ba... | \n",
+ " But they're just things. | \n",
+ " Deck two oh seven. Apple Sesame block, dwelli... | \n",
+ " Where'd she go? | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " Come on, use your eyes. Notice everything. Wh... | \n",
+ " What's wrong? | \n",
+ " Oh, lovely. You're a cheery one. Never mind d... | \n",
+ " I'm in the future. Like hundreds of years in ... | \n",
+ " Welcome to London Market. You are being monit... | \n",
+ " Doctor? | \n",
+ " So we're like a wildlife documentary, yeah? B... | \n",
+ " Ooo, that's interesting. | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " Ooo, that's interesting. | \n",
+ " An important thing. In fact, Thing One. We ar... | \n",
+ " A thing? | \n",
+ " Course we can. But first, there's a thing. | \n",
+ " Can we go out and see? | \n",
+ " Well, come on. I've found us a spaceship. Thi... | \n",
+ " Doctor! | \n",
+ " Isn't that amazing? | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " response ... context/5\n",
+ "8 An important thing. In fact, Thing One. We ar... ... Doctor?\n",
+ "16 Don't know. I think a lot. It's hard to keep ... ... Oh my God, I'm in my nightie.\n",
+ "0 Come on, Pond. ... You can't ride a Vator with a zero. You know ...\n",
+ "23 What I always do. Stay out of trouble. Badly. ... Where'd she go?\n",
+ "11 Come on, use your eyes. Notice everything. Wh... ... Ooo, that's interesting.\n",
+ "9 Ooo, that's interesting. ... Isn't that amazing?\n",
+ "\n",
+ "[6 rows x 8 columns]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 75
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trn_df, val_df = train_test_split(df, test_size=0.1)"
+ ],
+ "metadata": {
+ "id": "nYM_4zKirQ5A"
+ },
+ "execution_count": 76,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trn_df.head()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 424
+ },
+ "id": "RKF8dGVxS61X",
+ "outputId": "46d99699-9659-4d69-fd54-16b290ec2491"
+ },
+ "execution_count": 77,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " response | \n",
+ " context | \n",
+ " context/0 | \n",
+ " context/1 | \n",
+ " context/2 | \n",
+ " context/3 | \n",
+ " context/4 | \n",
+ " context/5 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 7 | \n",
+ " Course we can. But first, there's a thing. | \n",
+ " Can we go out and see? | \n",
+ " Well, come on. I've found us a spaceship. Thi... | \n",
+ " Doctor! | \n",
+ " Isn't that amazing? | \n",
+ " Doctor? | \n",
+ " Migrating to the stars. | \n",
+ " Doctor? | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " There. | \n",
+ " Where? | \n",
+ " Don't know. I think a lot. It's hard to keep ... | \n",
+ " Why did you just do that with the water? | \n",
+ " Sorry. Checking all the water in this area. T... | \n",
+ " What are you doing? | \n",
+ " Life on a giant starship. Back to basics. Bic... | \n",
+ " London Market is a crime-free zone. | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Now do you believe me? | \n",
+ " And my imaginary friend came back. | \n",
+ " Come on, Pond. | \n",
+ " My name is Amy Pond. When I was seven, I had ... | \n",
+ " Help! Help me! | \n",
+ " Though the man above might say hello, expect ... | \n",
+ " A horse and a man, above, below. One has a pl... | \n",
+ " Welcome to Vator Verse, sponsored by McLintoc... | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " Deck two oh seven. Apple Sesame block, dwelli... | \n",
+ " Where'd she go? | \n",
+ " Hundreds of parents walking past who spot her... | \n",
+ " Are you a parent? | \n",
+ " Crying silently. I mean, children cry because... | \n",
+ " One little girl crying. So? | \n",
+ " I'll have a look on the monitors. | \n",
+ " Apparently. | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " Come on, Pond. | \n",
+ " My name is Amy Pond. When I was seven, I had ... | \n",
+ " Help! Help me! | \n",
+ " Though the man above might say hello, expect ... | \n",
+ " A horse and a man, above, below. One has a pl... | \n",
+ " Welcome to Vator Verse, sponsored by McLintoc... | \n",
+ " I'll wait for you. | \n",
+ " You can't ride a Vator with a zero. You know ... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " response ... context/5\n",
+ "7 Course we can. But first, there's a thing. ... Doctor?\n",
+ "17 There. ... London Market is a crime-free zone.\n",
+ "1 Now do you believe me? ... Welcome to Vator Verse, sponsored by McLintoc...\n",
+ "20 Deck two oh seven. Apple Sesame block, dwelli... ... Apparently.\n",
+ "0 Come on, Pond. ... You can't ride a Vator with a zero. You know ...\n",
+ "\n",
+ "[5 rows x 8 columns]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 77
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# create dataset suitable for our model\n",
+ "def construct_conv(row, tokenizer, eos = True):\n",
+ " flatten = lambda l: [item for sublist in l for item in sublist]\n",
+ " conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n",
+ " conv = flatten(conv)\n",
+ " return conv\n",
+ "\n",
+ "class ConversationDataset(Dataset):\n",
+ " def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n",
+ "\n",
+ " block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n",
+ "\n",
+ " directory = args.cache_dir\n",
+ " cached_features_file = os.path.join(\n",
+ " directory, args.model_type + \"_cached_lm_\" + str(block_size)\n",
+ " )\n",
+ "\n",
+ " if os.path.exists(cached_features_file) and not args.overwrite_cache:\n",
+ " logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
+ " with open(cached_features_file, \"rb\") as handle:\n",
+ " self.examples = pickle.load(handle)\n",
+ " else:\n",
+ " logger.info(\"Creating features from dataset file at %s\", directory)\n",
+ "\n",
+ " self.examples = []\n",
+ " for _, row in df.iterrows():\n",
+ " conv = construct_conv(row, tokenizer)\n",
+ " self.examples.append(conv)\n",
+ "\n",
+ " logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
+ " with open(cached_features_file, \"wb\") as handle:\n",
+ " pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.examples)\n",
+ "\n",
+ " def __getitem__(self, item):\n",
+ " return torch.tensor(self.examples[item], dtype=torch.long)"
+ ],
+ "metadata": {
+ "id": "va9Olm-DoR9w"
+ },
+ "execution_count": 78,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Cacheing and storing of data/checkpoints\n",
+ "\n",
+ "def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n",
+ " return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n",
+ "\n",
+ "\n",
+ "def set_seed(args):\n",
+ " random.seed(args.seed)\n",
+ " np.random.seed(args.seed)\n",
+ " torch.manual_seed(args.seed)\n",
+ " if args.n_gpu > 0:\n",
+ " torch.cuda.manual_seed_all(args.seed)\n",
+ "\n",
+ "\n",
+ "def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n",
+ " ordering_and_checkpoint_path = []\n",
+ "\n",
+ " glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n",
+ "\n",
+ " for path in glob_checkpoints:\n",
+ " if use_mtime:\n",
+ " ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n",
+ " else:\n",
+ " regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n",
+ " if regex_match and regex_match.groups():\n",
+ " ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n",
+ "\n",
+ " checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n",
+ " checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n",
+ " return checkpoints_sorted\n",
+ "\n",
+ "\n",
+ "def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n",
+ " if not args.save_total_limit:\n",
+ " return\n",
+ " if args.save_total_limit <= 0:\n",
+ " return\n",
+ "\n",
+ " # Check if we should delete older checkpoint(s)\n",
+ " checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n",
+ " if len(checkpoints_sorted) <= args.save_total_limit:\n",
+ " return\n",
+ "\n",
+ " number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n",
+ " checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n",
+ " for checkpoint in checkpoints_to_be_deleted:\n",
+ " logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n",
+ " shutil.rmtree(checkpoint)"
+ ],
+ "metadata": {
+ "id": "wj1yakcqTCNx"
+ },
+ "execution_count": 79,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Build Model"
+ ],
+ "metadata": {
+ "id": "PNcEUFjSoML0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n",
+ "import torch\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n",
+ "model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UTI0g8PPg628",
+ "outputId": "8b8c6e19-fa17-42be-c9e4-25e5ae4c1819"
+ },
+ "execution_count": 80,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:787: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
+ " FutureWarning,\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# configs\n",
+ "\n",
+ "logger = logging.getLogger(__name__)\n",
+ "\n",
+ "MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n",
+ "MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)"
+ ],
+ "metadata": {
+ "id": "lzDAg6-eg7Fj"
+ },
+ "execution_count": 81,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Args to allow for easy convertion of python script to notebook\n",
+ "class Args():\n",
+ " def __init__(self):\n",
+ " self.output_dir = 'output-small'\n",
+ " self.model_type = 'gpt2'\n",
+ " self.model_name_or_path = 'microsoft/DialoGPT-small'\n",
+ " self.config_name = 'microsoft/DialoGPT-small'\n",
+ " self.tokenizer_name = 'microsoft/DialoGPT-small'\n",
+ " self.cache_dir = 'cached'\n",
+ " self.block_size = 512\n",
+ " self.do_train = True\n",
+ " self.do_eval = True\n",
+ " self.evaluate_during_training = False\n",
+ " self.per_gpu_train_batch_size = 4\n",
+ " self.per_gpu_eval_batch_size = 4\n",
+ " self.gradient_accumulation_steps = 1\n",
+ " self.learning_rate = 5e-5\n",
+ " self.weight_decay = 0.0\n",
+ " self.adam_epsilon = 1e-8\n",
+ " self.max_grad_norm = 1.0\n",
+ " self.num_train_epochs = 4\n",
+ " self.max_steps = -1\n",
+ " self.warmup_steps = 0\n",
+ " self.logging_steps = 1000\n",
+ " self.save_steps = 3500\n",
+ " self.save_total_limit = None\n",
+ " self.eval_all_checkpoints = False\n",
+ " self.no_cuda = False\n",
+ " self.overwrite_output_dir = True\n",
+ " self.overwrite_cache = True\n",
+ " self.should_continue = False\n",
+ " self.seed = 42\n",
+ " self.local_rank = -1\n",
+ " self.fp16 = False\n",
+ " self.fp16_opt_level = 'O1'\n",
+ "\n",
+ "args = Args()"
+ ],
+ "metadata": {
+ "id": "8b20p10Xg7M-"
+ },
+ "execution_count": 82,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Train and Evaluate"
+ ],
+ "metadata": {
+ "id": "9QaybLujoTg-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n",
+ " \"\"\" Train the model \"\"\"\n",
+ " if args.local_rank in [-1, 0]:\n",
+ " tb_writer = SummaryWriter()\n",
+ "\n",
+ " args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
+ "\n",
+ " def collate(examples: List[torch.Tensor]):\n",
+ " if tokenizer._pad_token is None:\n",
+ " return pad_sequence(examples, batch_first=True)\n",
+ " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
+ "\n",
+ " train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n",
+ " train_dataloader = DataLoader(\n",
+ " train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n",
+ " )\n",
+ "\n",
+ " if args.max_steps > 0:\n",
+ " t_total = args.max_steps\n",
+ " args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n",
+ " else:\n",
+ " t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n",
+ "\n",
+ " model = model.module if hasattr(model, \"module\") else model # Take care of distributed/parallel training\n",
+ " model.resize_token_embeddings(len(tokenizer))\n",
+ " # add_special_tokens_(model, tokenizer)\n",
+ "\n",
+ "\n",
+ " # Prepare optimizer and schedule (linear warmup and decay)\n",
+ " no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
+ " optimizer_grouped_parameters = [\n",
+ " {\n",
+ " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
+ " \"weight_decay\": args.weight_decay,\n",
+ " },\n",
+ " {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
+ " ]\n",
+ " optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
+ " scheduler = get_linear_schedule_with_warmup(\n",
+ " optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n",
+ " )\n",
+ "\n",
+ " # Check if saved optimizer or scheduler states exist\n",
+ " if (\n",
+ " args.model_name_or_path\n",
+ " and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n",
+ " and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n",
+ " ):\n",
+ " # Load in optimizer and scheduler states\n",
+ " optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n",
+ " scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n",
+ "\n",
+ " if args.fp16:\n",
+ " try:\n",
+ " from apex import amp\n",
+ " except ImportError:\n",
+ " raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
+ " model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n",
+ "\n",
+ " # multi-gpu training (should be after apex fp16 initialization)\n",
+ " if args.n_gpu > 1:\n",
+ " model = torch.nn.DataParallel(model)\n",
+ "\n",
+ " # Distributed training (should be after apex fp16 initialization)\n",
+ " if args.local_rank != -1:\n",
+ " model = torch.nn.parallel.DistributedDataParallel(\n",
+ " model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n",
+ " )\n",
+ "\n",
+ " # Train!\n",
+ " logger.info(\"***** Running training *****\")\n",
+ " logger.info(\" Num examples = %d\", len(train_dataset))\n",
+ " logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n",
+ " logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n",
+ " logger.info(\n",
+ " \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
+ " args.train_batch_size\n",
+ " * args.gradient_accumulation_steps\n",
+ " * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n",
+ " )\n",
+ " logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n",
+ " logger.info(\" Total optimization steps = %d\", t_total)\n",
+ "\n",
+ " global_step = 0\n",
+ " epochs_trained = 0\n",
+ " steps_trained_in_current_epoch = 0\n",
+ " # Check if continuing training from a checkpoint\n",
+ " if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n",
+ " try:\n",
+ " # set global_step to gobal_step of last saved checkpoint from model path\n",
+ " checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n",
+ " global_step = int(checkpoint_suffix)\n",
+ " epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n",
+ " steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n",
+ "\n",
+ " logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n",
+ " logger.info(\" Continuing training from epoch %d\", epochs_trained)\n",
+ " logger.info(\" Continuing training from global step %d\", global_step)\n",
+ " logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n",
+ " except ValueError:\n",
+ " logger.info(\" Starting fine-tuning.\")\n",
+ "\n",
+ " tr_loss, logging_loss = 0.0, 0.0\n",
+ "\n",
+ " model.zero_grad()\n",
+ " train_iterator = trange(\n",
+ " epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n",
+ " )\n",
+ " set_seed(args) # Added here for reproducibility\n",
+ " for _ in train_iterator:\n",
+ " epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n",
+ " for step, batch in enumerate(epoch_iterator):\n",
+ "\n",
+ " # Skip past any already trained steps if resuming training\n",
+ " if steps_trained_in_current_epoch > 0:\n",
+ " steps_trained_in_current_epoch -= 1\n",
+ " continue\n",
+ "\n",
+ " inputs, labels = (batch, batch)\n",
+ " if inputs.shape[1] > 1024: continue\n",
+ " inputs = inputs.to(args.device)\n",
+ " labels = labels.to(args.device)\n",
+ " model.train()\n",
+ " outputs = model(inputs, labels=labels)\n",
+ " loss = outputs[0] # model outputs are always tuple in transformers (see doc)\n",
+ "\n",
+ " if args.n_gpu > 1:\n",
+ " loss = loss.mean() # mean() to average on multi-gpu parallel training\n",
+ " if args.gradient_accumulation_steps > 1:\n",
+ " loss = loss / args.gradient_accumulation_steps\n",
+ "\n",
+ " if args.fp16:\n",
+ " with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
+ " scaled_loss.backward()\n",
+ " else:\n",
+ " loss.backward()\n",
+ "\n",
+ " tr_loss += loss.item()\n",
+ " if (step + 1) % args.gradient_accumulation_steps == 0:\n",
+ " if args.fp16:\n",
+ " torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
+ " else:\n",
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n",
+ " optimizer.step()\n",
+ " scheduler.step() # Update learning rate schedule\n",
+ " model.zero_grad()\n",
+ " global_step += 1\n",
+ "\n",
+ " if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n",
+ " # Log metrics\n",
+ " if (\n",
+ " args.local_rank == -1 and args.evaluate_during_training\n",
+ " ): # Only evaluate when single GPU otherwise metrics may not average well\n",
+ " results = evaluate(args, model, tokenizer)\n",
+ " for key, value in results.items():\n",
+ " tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
+ " tb_writer.add_scalar(\"lr\", scheduler.get_lr()[0], global_step)\n",
+ " tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n",
+ " logging_loss = tr_loss\n",
+ "\n",
+ " if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n",
+ " checkpoint_prefix = \"checkpoint\"\n",
+ " # Save model checkpoint\n",
+ " output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n",
+ " os.makedirs(output_dir, exist_ok=True)\n",
+ " model_to_save = (\n",
+ " model.module if hasattr(model, \"module\") else model\n",
+ " ) # Take care of distributed/parallel training\n",
+ " model_to_save.save_pretrained(output_dir)\n",
+ " tokenizer.save_pretrained(output_dir)\n",
+ "\n",
+ " torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n",
+ " logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
+ "\n",
+ " _rotate_checkpoints(args, checkpoint_prefix)\n",
+ "\n",
+ " torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
+ " torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
+ " logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n",
+ "\n",
+ " if args.max_steps > 0 and global_step > args.max_steps:\n",
+ " epoch_iterator.close()\n",
+ " break\n",
+ " if args.max_steps > 0 and global_step > args.max_steps:\n",
+ " train_iterator.close()\n",
+ " break\n",
+ "\n",
+ " if args.local_rank in [-1, 0]:\n",
+ " tb_writer.close()\n",
+ "\n",
+ " return global_step, tr_loss / global_step\n",
+ "\n",
+ "# Evaluation of some model\n",
+ "\n",
+ "def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n",
+ " # Loop to handle MNLI double evaluation (matched, mis-matched)\n",
+ " eval_output_dir = args.output_dir\n",
+ "\n",
+ " eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n",
+ " os.makedirs(eval_output_dir, exist_ok=True)\n",
+ " args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n",
+ " # Note that DistributedSampler samples randomly\n",
+ "\n",
+ " def collate(examples: List[torch.Tensor]):\n",
+ " if tokenizer._pad_token is None:\n",
+ " return pad_sequence(examples, batch_first=True)\n",
+ " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
+ "\n",
+ " eval_sampler = SequentialSampler(eval_dataset)\n",
+ " eval_dataloader = DataLoader(\n",
+ " eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n",
+ " )\n",
+ "\n",
+ " # multi-gpu evaluate\n",
+ " if args.n_gpu > 1:\n",
+ " model = torch.nn.DataParallel(model)\n",
+ "\n",
+ " # Eval!\n",
+ " logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
+ " logger.info(\" Num examples = %d\", len(eval_dataset))\n",
+ " logger.info(\" Batch size = %d\", args.eval_batch_size)\n",
+ " eval_loss = 0.0\n",
+ " nb_eval_steps = 0\n",
+ " model.eval()\n",
+ "\n",
+ " for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
+ " inputs, labels = (batch, batch)\n",
+ " inputs = inputs.to(args.device)\n",
+ " labels = labels.to(args.device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " outputs = model(inputs, labels=labels)\n",
+ " lm_loss = outputs[0]\n",
+ " eval_loss += lm_loss.mean().item()\n",
+ " nb_eval_steps += 1\n",
+ "\n",
+ " eval_loss = eval_loss / (nb_eval_steps + 0.00001)\n",
+ " perplexity = torch.exp(torch.tensor(eval_loss))\n",
+ "\n",
+ " result = {\"perplexity\": perplexity}\n",
+ "\n",
+ " output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
+ " with open(output_eval_file, \"w\") as writer:\n",
+ " logger.info(\"***** Eval results {} *****\".format(prefix))\n",
+ " for key in sorted(result.keys()):\n",
+ " logger.info(\" %s = %s\", key, str(result[key]))\n",
+ " writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
+ "\n",
+ " return result"
+ ],
+ "metadata": {
+ "id": "Yd7cAl8-oVSR"
+ },
+ "execution_count": 83,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Main runner\n",
+ "\n",
+ "def main(df_trn, df_val):\n",
+ " args = Args()\n",
+ " \n",
+ " if args.should_continue:\n",
+ " sorted_checkpoints = _sorted_checkpoints(args)\n",
+ " if len(sorted_checkpoints) == 0:\n",
+ " raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
+ " else:\n",
+ " args.model_name_or_path = sorted_checkpoints[-1]\n",
+ "\n",
+ " if (\n",
+ " os.path.exists(args.output_dir)\n",
+ " and os.listdir(args.output_dir)\n",
+ " and args.do_train\n",
+ " and not args.overwrite_output_dir\n",
+ " and not args.should_continue\n",
+ " ):\n",
+ " raise ValueError(\n",
+ " \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
+ " args.output_dir\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " # Setup CUDA, GPU & distributed training\n",
+ " device = torch.device(\"cuda\")\n",
+ " args.n_gpu = torch.cuda.device_count()\n",
+ " args.device = device\n",
+ "\n",
+ " # Setup logging\n",
+ " logging.basicConfig(\n",
+ " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
+ " datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
+ " level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n",
+ " )\n",
+ " logger.warning(\n",
+ " \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
+ " args.local_rank,\n",
+ " device,\n",
+ " args.n_gpu,\n",
+ " bool(args.local_rank != -1),\n",
+ " args.fp16,\n",
+ " )\n",
+ "\n",
+ " # Set seed\n",
+ " set_seed(args)\n",
+ "\n",
+ " config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n",
+ " tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n",
+ " model = AutoModelWithLMHead.from_pretrained(\n",
+ " args.model_name_or_path,\n",
+ " from_tf=False,\n",
+ " config=config,\n",
+ " cache_dir=args.cache_dir,\n",
+ " )\n",
+ " model.to(args.device)\n",
+ " \n",
+ " logger.info(\"Training/evaluation parameters %s\", args)\n",
+ "\n",
+ " # Training\n",
+ " if args.do_train:\n",
+ " train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n",
+ "\n",
+ " global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n",
+ " logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
+ "\n",
+ " # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n",
+ " if args.do_train:\n",
+ " # Create output directory if needed\n",
+ " os.makedirs(args.output_dir, exist_ok=True)\n",
+ "\n",
+ " logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n",
+ " # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
+ " # They can then be reloaded using `from_pretrained()`\n",
+ " model_to_save = (\n",
+ " model.module if hasattr(model, \"module\") else model\n",
+ " ) # Take care of distributed/parallel training\n",
+ " model_to_save.save_pretrained(args.output_dir)\n",
+ " tokenizer.save_pretrained(args.output_dir)\n",
+ "\n",
+ " # Good practice: save your training arguments together with the trained model\n",
+ " torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n",
+ "\n",
+ " # Load a trained model and vocabulary that you have fine-tuned\n",
+ " model = AutoModelWithLMHead.from_pretrained(args.output_dir)\n",
+ " tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
+ " model.to(args.device)\n",
+ "\n",
+ " # Evaluation\n",
+ " results = {}\n",
+ " if args.do_eval and args.local_rank in [-1, 0]:\n",
+ " checkpoints = [args.output_dir]\n",
+ " if args.eval_all_checkpoints:\n",
+ " checkpoints = list(\n",
+ " os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
+ " )\n",
+ " logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n",
+ " logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
+ " for checkpoint in checkpoints:\n",
+ " global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
+ " prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
+ "\n",
+ " model = AutoModelWithLMHead.from_pretrained(checkpoint)\n",
+ " model.to(args.device)\n",
+ " result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n",
+ " result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n",
+ " results.update(result)\n",
+ "\n",
+ " return results"
+ ],
+ "metadata": {
+ "id": "M93fjuFwiu-T"
+ },
+ "execution_count": 84,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run The Main Function"
+ ],
+ "metadata": {
+ "id": "jdBULkbmoX6E"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "main(trn_df, val_df)"
+ ],
+ "metadata": {
+ "id": "IhQ-I1_Vobx0",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 646,
+ "referenced_widgets": [
+ "a49e5fd0d85444a3aa9f786455ca8770",
+ "73e8d052a86647919649a367aa773c8e",
+ "40760124752846209e61177280a005bd",
+ "eae3f41495884830818311e51920c956",
+ "5d1a116b987549d780ee25723f83d45a",
+ "46f7e33281354ef488945f5f1cfe4c06",
+ "e987dfed8c624717b5ae2054cce74f05",
+ "c49e73bdfe544de0ae62034cef7eb0da",
+ "52e599d90ccd44938d982310fb7e4341",
+ "c01768976adf465ebfad5c3eedfe1d58",
+ "27fb7c7e261b4a3b9656a37b1fcde71a",
+ "1c82670ef31346eb97dff63429fd522f",
+ "a8c2fda5e0be4c638919b4ca1007dea3",
+ "5e60bbde81ed452fa0c8d7094d98b052",
+ "3ac055433ca94c2ebe9f8b44e38be5e0",
+ "67e70ffbe152488fb036968be105a368",
+ "bbeb9a01f5bb4aebba239db555f4b16b",
+ "771602cc4d9444e7ab0d20438639cddd",
+ "f154ee2be8a044b3aeeb0e904411ffbd",
+ "a3a6841089054f1cbc31f638424674b3",
+ "41061818a9c94956a7d1cd129028d805",
+ "bab64ef3864248018e9476bc8c4018f4",
+ "f3fa20cd1c40453bb17b2f109607e1bf",
+ "97b3a5270a014515bbc712b44dba38a0",
+ "afe3a3438fb145d4a015fdb0709e3156",
+ "6d6078316fe54c9e83a3c3a35a1169fc",
+ "4700a281e7d347db8a58c6f181706b54",
+ "84f8bdfeb6bf4bb7ba4585eba47a7092",
+ "9262f880b64a4abb80013f6997901bcb",
+ "386289f0bf56453484a6637d3263da4c",
+ "8366904443f649928aa9cfd915cd938a",
+ "80f977590ae94733a8a8552241c12e3b",
+ "3ff5971c055144a3b81190d199ffe3de",
+ "be5e0fa21fea43e8bf003ae954c29d03",
+ "7535306bf05847629946e333021e0ef5",
+ "671f7a7556b1412cbd48237293431c0d",
+ "8957849d6dbf44cfafa965d71de78255",
+ "a55488797f71453e917032008f198b9c",
+ "38ccc42c4ea14f0e83fca1cb9452bfad",
+ "ef2ddbabd7b042c0821cf999a9265867",
+ "0222497184b2446e90f801d84af22b82",
+ "6a70d59d1c834989a54deda9e776bf41",
+ "62e05dc21f5e438cb5e94e400071a39b",
+ "803884fef29b433db14e61df7fae1ee7",
+ "83414a06fd504f71aa212d9fce15ffb5",
+ "e2c8785b7c51448296c8cf54331f4a68",
+ "3f9881738b644024aea6371982320d97",
+ "143666ff0c7b4e6491779b64f6212818",
+ "bf699df8e3d844f68a68491c00e8f0bc",
+ "9e955ce77097447a8c085ea592ae8a5e",
+ "2f2aecb73861473ba553b4ebebd52e0b",
+ "db21f601fd6342a5815cea18c417aa99",
+ "67bf9d4f1a8e431bb580337be0e67f82",
+ "ddcba8098de6433a8584045d52cb1f3b",
+ "60062f7860944d0d85c4ef1773c151c3",
+ "cc13e655b33d4fa390960d1fa40a0e1f",
+ "2e418d21ae4f4123a9d7b13cbc368605",
+ "ba1f476b8bdc4fce8a703a0220bd4770",
+ "26560e743afa4a3fad0eb1e0ed567a64",
+ "27e7a38811a94548b8dc1980e1c83acd",
+ "c5ceabb016a74435be6659f1116c9945",
+ "4ca6ee080dad41aa8abbbea5a96e3922",
+ "c0fd4025a3e84d0ca8c360887d7126ba",
+ "4a374856a56c4dc6a163ee5779d6b666",
+ "ef067c6b95ac48c58428f62ecef22e33",
+ "88521418122646b0b6c7d41be73e747a"
+ ]
+ },
+ "outputId": "612cca73-0fa4-41c0-f2a3-ba9df82d4b4c"
+ },
+ "execution_count": 85,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "01/24/2022 12:02:55 - WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False\n",
+ "/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:787: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
+ " FutureWarning,\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Training/evaluation parameters <__main__.Args object at 0x7f0555d60550>\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Creating features from dataset file at cached\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Saving features into cached file cached/gpt2_cached_lm_512\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - ***** Running training *****\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Num examples = 22\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Num Epochs = 4\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Instantaneous batch size per GPU = 4\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 4\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Gradient Accumulation steps = 1\n",
+ "01/24/2022 12:03:00 - INFO - __main__ - Total optimization steps = 20\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a49e5fd0d85444a3aa9f786455ca8770",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "Epoch: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1c82670ef31346eb97dff63429fd522f",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "Iteration: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f3fa20cd1c40453bb17b2f109607e1bf",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "Iteration: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "be5e0fa21fea43e8bf003ae954c29d03",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "Iteration: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "83414a06fd504f71aa212d9fce15ffb5",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "Iteration: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "01/24/2022 12:03:12 - INFO - __main__ - global_step = 20, average loss = 4.869351613521576\n",
+ "01/24/2022 12:03:12 - INFO - __main__ - Saving model checkpoint to output-small\n",
+ "01/24/2022 12:03:17 - INFO - __main__ - Evaluate the following checkpoints: ['output-small']\n",
+ "01/24/2022 12:03:20 - INFO - __main__ - Creating features from dataset file at cached\n",
+ "01/24/2022 12:03:20 - INFO - __main__ - Saving features into cached file cached/gpt2_cached_lm_512\n",
+ "01/24/2022 12:03:20 - INFO - __main__ - ***** Running evaluation *****\n",
+ "01/24/2022 12:03:20 - INFO - __main__ - Num examples = 3\n",
+ "01/24/2022 12:03:20 - INFO - __main__ - Batch size = 4\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cc13e655b33d4fa390960d1fa40a0e1f",
+ "version_minor": 0,
+ "version_major": 2
+ },
+ "text/plain": [
+ "Evaluating: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "01/24/2022 12:03:20 - INFO - __main__ - ***** Eval results *****\n",
+ "01/24/2022 12:03:20 - INFO - __main__ - perplexity = tensor(1.)\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "{'perplexity_': tensor(1.)}"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 85
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Load The Trained Model"
+ ],
+ "metadata": {
+ "id": "F_xqB94xocSg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')\n",
+ "model = AutoModelWithLMHead.from_pretrained('output-small')"
+ ],
+ "metadata": {
+ "id": "9AGnm3bmofkU",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "c6949113-68cb-4373-f390-983ea871c7ed"
+ },
+ "execution_count": 86,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:787: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
+ " FutureWarning,\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Let's chat for 4 lines\n",
+ "for step in range(4):\n",
+ " # encode the new user input, add the eos_token and return a tensor in Pytorch\n",
+ " new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n",
+ " # print(new_user_input_ids)\n",
+ "\n",
+ " # append the new user input tokens to the chat history\n",
+ " bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n",
+ "\n",
+ " # generated a response while limiting the total chat history to 1000 tokens, \n",
+ " chat_history_ids = model.generate(\n",
+ " bot_input_ids, max_length=200,\n",
+ " pad_token_id=tokenizer.eos_token_id, \n",
+ " no_repeat_ngram_size=3, \n",
+ " do_sample=True, \n",
+ " top_k=100, \n",
+ " top_p=0.7,\n",
+ " temperature=0.8\n",
+ " )\n",
+ " \n",
+ " # pretty print last ouput tokens from bot\n",
+ " print(\"EleventhDoctorBot: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))"
+ ],
+ "metadata": {
+ "id": "2-CQOEWPrmE7",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "fb711e82-492c-45a3-8aab-7e888bff12cb"
+ },
+ "execution_count": 90,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ ">> User:why are you not working\n",
+ "EleventhDoctorBot: I'm not a lawyer, but I'm going to go out on a limb and say that the defendant's name was already on the back of the police's desk.\n",
+ ">> User:i wish you were here\n",
+ "EleventhDoctorBot: I'll be back in a few hours to talk about this case and the police report.\n",
+ ">> User:what police\n",
+ "EleventhDoctorBot: I just got a new phone and it's my new phone so I'm not going to talk to anyone on it\n",
+ ">> User:haha\n",
+ "EleventhDoctorBot: I know, I just had a phone call and it was a phonecall that was being sent to the police and they were on it when they came back to\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Push Model to HuggingFace"
+ ],
+ "metadata": {
+ "id": "KQBRSKDcoiJ4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#model.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)\n",
+ "#tokenizer.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)"
+ ],
+ "metadata": {
+ "id": "E_IH5n-P2u3N"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "os.chdir(\"/content/\")"
+ ],
+ "metadata": {
+ "id": "tQtHvpnXA2fC"
+ },
+ "execution_count": 88,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "metadata": {
+ "id": "SJa0EMUZ-gYI"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file