Spaces:
Sleeping
Sleeping
File size: 57,326 Bytes
c046d7f |
1 |
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","authorship_tag":"ABX9TyNaEJSE86MIYL6MGO8lUr3+"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"f77bbc2602d846f8bb1c9e06f7b519ef":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_d2c87c66057f4e33bdbfe078ee47c0b2","IPY_MODEL_fe33fa70d2c540e9831d408ffb5d3af9","IPY_MODEL_61d7b89f706348548ebcaf0ced92a44e"],"layout":"IPY_MODEL_a51b1054233342feaef0b16d2627a658"}},"d2c87c66057f4e33bdbfe078ee47c0b2":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c4aa2aa530034af487957db71e9c509f","placeholder":"β","style":"IPY_MODEL_b4957dccdc5c4b83982212497febf4dc","value":"100%"}},"fe33fa70d2c540e9831d408ffb5d3af9":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_592c8770c24c4366843336989c8cdb5f","max":2,
"min":0,"orientation":"horizontal","style":"IPY_MODEL_f6c8ccfd74844b7ebafb37f181b51130","value":2}},"61d7b89f706348548ebcaf0ced92a44e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_178935e2880043dc9eca862c09c05c16","placeholder":"β","style":"IPY_MODEL_aac7d4c2eebf4eb1914d2d98c52243ce","value":" 2/2 [00:00<00:00, 46.37it/s]"}},"a51b1054233342feaef0b16d2627a658":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c4aa2aa530034af487957db71e9c509f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_mo
dule":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b4957dccdc5c4b83982212497febf4dc":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"592c8770c24c4366843336989c8cdb5f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":nul
l,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f6c8ccfd74844b7ebafb37f181b51130":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"178935e2880043dc9eca862c09c05c16":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"aac7d4c2eebf4eb1914d2d98c52243ce":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base"
,"_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["# Introduction\n","\n","This is a training script for a diffusion model called MinImagen. A smaller adaptation of original Imagen architecture introduced by Google."],"metadata":{"id":"yQMHUdQADLcT"}},{"cell_type":"markdown","source":["# Setup"],"metadata":{"id":"Qz5vkO4bEh9b"}},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"LeDWeQuVl_6s","executionInfo":{"status":"ok","timestamp":1690783982884,"user_tz":-330,"elapsed":117837,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"6fb30b50-90aa-4217-82e9-8fdea0826b9c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting minimagen\n"," Downloading minimagen-0.0.9-py3-none-any.whl (43 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m43.0/43.0 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting aiohttp==3.8.1 (from minimagen)\n"," Downloading aiohttp-3.8.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.2 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m21.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting aiosignal==1.2.0 (from minimagen)\n"," Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n","Requirement already satisfied: async-timeout==4.0.2 in /usr/local/lib/python3.10/dist-packages (from minimagen) (4.0.2)\n","Collecting attrs==21.4.0 (from minimagen)\n"," Downloading attrs-21.4.0-py2.py3-none-any.whl (60 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m60.6/60.6 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting certifi==2022.6.15 (from 
minimagen)\n"," Downloading certifi-2022.6.15-py3-none-any.whl (160 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m160.2/160.2 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting charset-normalizer==2.1.0 (from minimagen)\n"," Downloading charset_normalizer-2.1.0-py3-none-any.whl (39 kB)\n","Collecting colorama==0.4.5 (from minimagen)\n"," Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB)\n","Collecting datasets==2.3.2 (from minimagen)\n"," Downloading datasets-2.3.2-py3-none-any.whl (362 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m362.3/362.3 kB\u001b[0m \u001b[31m36.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting dill==0.3.5.1 (from minimagen)\n"," Downloading dill-0.3.5.1-py2.py3-none-any.whl (95 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting einops==0.4.1 (from minimagen)\n"," Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n","Collecting einops-exts==0.0.3 (from minimagen)\n"," Downloading einops_exts-0.0.3-py3-none-any.whl (3.8 kB)\n","Collecting filelock==3.7.1 (from minimagen)\n"," Downloading filelock-3.7.1-py3-none-any.whl (10 kB)\n","Collecting frozenlist==1.3.0 (from minimagen)\n"," Downloading frozenlist-1.3.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (157 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m157.9/157.9 kB\u001b[0m \u001b[31m19.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting fsspec==2022.5.0 (from minimagen)\n"," Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m140.6/140.6 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting future==0.18.2 (from minimagen)\n"," Downloading future-0.18.2.tar.gz (829 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m829.2/829.2 kB\u001b[0m \u001b[31m71.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n","Collecting huggingface-hub==0.8.1 (from minimagen)\n"," Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m101.5/101.5 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting idna==3.3 (from minimagen)\n"," Downloading idna-3.3-py3-none-any.whl (61 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m61.2/61.2 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multidict==6.0.2 (from minimagen)\n"," Downloading multidict-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m114.5/114.5 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multiprocess==0.70.13 (from minimagen)\n"," Downloading multiprocess-0.70.13-py310-none-any.whl (133 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m133.1/133.1 kB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting numpy==1.23.1 (from minimagen)\n"," Downloading numpy-1.23.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.0 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m17.0/17.0 MB\u001b[0m \u001b[31m99.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting packaging==21.3 (from minimagen)\n"," Downloading packaging-21.3-py3-none-any.whl (40 
kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pandas==1.4.3 (from minimagen)\n"," Downloading pandas-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)\n","\u001b[2K \u001b[90mβββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m105.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting Pillow==9.2.0 (from minimagen)\n"," Downloading Pillow-9.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (3.2 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m3.2/3.2 MB\u001b[0m \u001b[31m80.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pyarrow==8.0.0 (from minimagen)\n"," Downloading pyarrow-8.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m29.4/29.4 MB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pyparsing==3.0.9 (from minimagen)\n"," Downloading pyparsing-3.0.9-py3-none-any.whl (98 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m98.3/98.3 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: python-dateutil==2.8.2 in /usr/local/lib/python3.10/dist-packages (from minimagen) (2.8.2)\n","Collecting pytz==2022.1 (from minimagen)\n"," Downloading pytz-2022.1-py2.py3-none-any.whl (503 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m503.5/503.5 kB\u001b[0m \u001b[31m44.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting PyYAML==6.0 (from minimagen)\n"," Downloading PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (682 
kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m682.2/682.2 kB\u001b[0m \u001b[31m56.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting regex==2022.7.9 (from minimagen)\n"," Downloading regex-2022.7.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (764 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m764.0/764.0 kB\u001b[0m \u001b[31m60.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting requests==2.28.1 (from minimagen)\n"," Downloading requests-2.28.1-py3-none-any.whl (62 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m62.8/62.8 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting resize-right==0.0.2 (from minimagen)\n"," Downloading resize_right-0.0.2-py3-none-any.whl (8.9 kB)\n","Collecting responses==0.18.0 (from minimagen)\n"," Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n","Collecting sentencepiece==0.1.96 (from minimagen)\n"," Downloading sentencepiece-0.1.96-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m82.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: six==1.16.0 in /usr/local/lib/python3.10/dist-packages (from minimagen) (1.16.0)\n","Collecting tdqm==0.0.1 (from minimagen)\n"," Downloading tdqm-0.0.1.tar.gz (1.4 kB)\n"," Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n","Collecting tokenizers==0.12.1 (from minimagen)\n"," Downloading tokenizers-0.12.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m6.6/6.6 MB\u001b[0m \u001b[31m99.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch==1.12.0 (from minimagen)\n"," Downloading torch-1.12.0-cp310-cp310-manylinux1_x86_64.whl (776.3 MB)\n","\u001b[2K \u001b[90mβββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m776.3/776.3 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torchvision==0.13.0 (from minimagen)\n"," Downloading torchvision-0.13.0-cp310-cp310-manylinux1_x86_64.whl (19.1 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m19.1/19.1 MB\u001b[0m \u001b[31m70.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting tqdm==4.64.0 (from minimagen)\n"," Downloading tqdm-4.64.0-py2.py3-none-any.whl (78 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m78.4/78.4 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting transformers==4.20.1 (from minimagen)\n"," Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m118.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting typing-extensions==4.3.0 (from minimagen)\n"," Downloading typing_extensions-4.3.0-py3-none-any.whl (25 kB)\n","Collecting urllib3==1.26.10 (from minimagen)\n"," Downloading urllib3-1.26.10-py2.py3-none-any.whl (139 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m139.2/139.2 kB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting xxhash==3.0.0 (from minimagen)\n"," 
Downloading xxhash-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (211 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m211.6/211.6 kB\u001b[0m \u001b[31m25.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting yarl==1.7.2 (from minimagen)\n"," Downloading yarl-1.7.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (305 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m305.3/305.3 kB\u001b[0m \u001b[31m32.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.3.2->minimagen) (2023.6.0)\n","INFO: pip is looking at multiple versions of fsspec[http] to determine which version is compatible with other requirements. This could take a while.\n","Collecting fsspec[http]>=2021.05.0 (from datasets==2.3.2->minimagen)\n"," Downloading fsspec-2023.5.0-py3-none-any.whl (160 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m160.1/160.1 kB\u001b[0m \u001b[31m19.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2023.4.0-py3-none-any.whl (153 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m154.0/154.0 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2023.3.0-py3-none-any.whl (145 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m145.4/145.4 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2023.1.0-py3-none-any.whl (143 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m143.0/143.0 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.11.0-py3-none-any.whl (139 
kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m139.5/139.5 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.10.0-py3-none-any.whl (138 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m138.8/138.8 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.8.2-py3-none-any.whl (140 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m19.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hINFO: pip is looking at multiple versions of fsspec[http] to determine which version is compatible with other requirements. This could take a while.\n"," Downloading fsspec-2022.7.1-py3-none-any.whl (141 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m141.2/141.2 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.7.0-py3-none-any.whl (141 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m141.2/141.2 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hBuilding wheels for collected packages: future, tdqm\n"," Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491057 sha256=7027523a9d22db99ee9aceb42c9fee72cfd0c8ed02e5907796d4270817864786\n"," Stored in directory: /root/.cache/pip/wheels/22/73/06/557dc4f4ef68179b9d763930d6eec26b88ed7c389b19588a1c\n"," Building wheel for tdqm (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n"," Created wheel for tdqm: filename=tdqm-0.0.1-py3-none-any.whl size=1321 sha256=78bbe236ae25f778666a9625686232a20f68f3d889132d56ea15c515fd88a311\n"," Stored in directory: /root/.cache/pip/wheels/37/31/b8/7b711038035720ba0df14376af06e5e76b9bd61759c861ad92\n","Successfully built future tdqm\n","Installing collected packages: tokenizers, sentencepiece, resize-right, pytz, einops, xxhash, urllib3, typing-extensions, tqdm, regex, PyYAML, pyparsing, Pillow, numpy, multidict, idna, future, fsspec, frozenlist, filelock, einops-exts, dill, colorama, charset-normalizer, certifi, attrs, yarl, torch, tdqm, requests, pyarrow, pandas, packaging, multiprocess, aiosignal, torchvision, responses, huggingface-hub, aiohttp, transformers, datasets, minimagen\n"," Attempting uninstall: pytz\n"," Found existing installation: pytz 2022.7.1\n"," Uninstalling pytz-2022.7.1:\n"," Successfully uninstalled pytz-2022.7.1\n"," Attempting uninstall: urllib3\n"," Found existing installation: urllib3 1.26.16\n"," Uninstalling urllib3-1.26.16:\n"," Successfully uninstalled urllib3-1.26.16\n"," Attempting uninstall: typing-extensions\n"," Found existing installation: typing_extensions 4.7.1\n"," Uninstalling typing_extensions-4.7.1:\n"," Successfully uninstalled typing_extensions-4.7.1\n"," Attempting uninstall: tqdm\n"," Found existing installation: tqdm 4.65.0\n"," Uninstalling tqdm-4.65.0:\n"," Successfully uninstalled tqdm-4.65.0\n"," Attempting uninstall: regex\n"," Found existing installation: regex 2022.10.31\n"," Uninstalling regex-2022.10.31:\n"," Successfully uninstalled regex-2022.10.31\n"," Attempting uninstall: PyYAML\n"," Found existing installation: PyYAML 6.0.1\n"," Uninstalling PyYAML-6.0.1:\n"," Successfully uninstalled PyYAML-6.0.1\n"," Attempting uninstall: pyparsing\n"," Found existing installation: pyparsing 3.1.0\n"," Uninstalling pyparsing-3.1.0:\n"," Successfully uninstalled pyparsing-3.1.0\n"," Attempting uninstall: Pillow\n"," Found existing 
installation: Pillow 9.4.0\n"," Uninstalling Pillow-9.4.0:\n"," Successfully uninstalled Pillow-9.4.0\n"," Attempting uninstall: numpy\n"," Found existing installation: numpy 1.22.4\n"," Uninstalling numpy-1.22.4:\n"," Successfully uninstalled numpy-1.22.4\n"," Attempting uninstall: multidict\n"," Found existing installation: multidict 6.0.4\n"," Uninstalling multidict-6.0.4:\n"," Successfully uninstalled multidict-6.0.4\n"," Attempting uninstall: idna\n"," Found existing installation: idna 3.4\n"," Uninstalling idna-3.4:\n"," Successfully uninstalled idna-3.4\n"," Attempting uninstall: future\n"," Found existing installation: future 0.18.3\n"," Uninstalling future-0.18.3:\n"," Successfully uninstalled future-0.18.3\n"," Attempting uninstall: fsspec\n"," Found existing installation: fsspec 2023.6.0\n"," Uninstalling fsspec-2023.6.0:\n"," Successfully uninstalled fsspec-2023.6.0\n"," Attempting uninstall: frozenlist\n"," Found existing installation: frozenlist 1.4.0\n"," Uninstalling frozenlist-1.4.0:\n"," Successfully uninstalled frozenlist-1.4.0\n"," Attempting uninstall: filelock\n"," Found existing installation: filelock 3.12.2\n"," Uninstalling filelock-3.12.2:\n"," Successfully uninstalled filelock-3.12.2\n"," Attempting uninstall: charset-normalizer\n"," Found existing installation: charset-normalizer 2.0.12\n"," Uninstalling charset-normalizer-2.0.12:\n"," Successfully uninstalled charset-normalizer-2.0.12\n"," Attempting uninstall: certifi\n"," Found existing installation: certifi 2023.7.22\n"," Uninstalling certifi-2023.7.22:\n"," Successfully uninstalled certifi-2023.7.22\n"," Attempting uninstall: attrs\n"," Found existing installation: attrs 23.1.0\n"," Uninstalling attrs-23.1.0:\n"," Successfully uninstalled attrs-23.1.0\n"," Attempting uninstall: yarl\n"," Found existing installation: yarl 1.9.2\n"," Uninstalling yarl-1.9.2:\n"," Successfully uninstalled yarl-1.9.2\n"," Attempting uninstall: torch\n"," Found existing installation: torch 
2.0.1+cu118\n"," Uninstalling torch-2.0.1+cu118:\n"," Successfully uninstalled torch-2.0.1+cu118\n"," Attempting uninstall: requests\n"," Found existing installation: requests 2.27.1\n"," Uninstalling requests-2.27.1:\n"," Successfully uninstalled requests-2.27.1\n"," Attempting uninstall: pyarrow\n"," Found existing installation: pyarrow 9.0.0\n"," Uninstalling pyarrow-9.0.0:\n"," Successfully uninstalled pyarrow-9.0.0\n"," Attempting uninstall: pandas\n"," Found existing installation: pandas 1.5.3\n"," Uninstalling pandas-1.5.3:\n"," Successfully uninstalled pandas-1.5.3\n"," Attempting uninstall: packaging\n"," Found existing installation: packaging 23.1\n"," Uninstalling packaging-23.1:\n"," Successfully uninstalled packaging-23.1\n"," Attempting uninstall: aiosignal\n"," Found existing installation: aiosignal 1.3.1\n"," Uninstalling aiosignal-1.3.1:\n"," Successfully uninstalled aiosignal-1.3.1\n"," Attempting uninstall: torchvision\n"," Found existing installation: torchvision 0.15.2+cu118\n"," Uninstalling torchvision-0.15.2+cu118:\n"," Successfully uninstalled torchvision-0.15.2+cu118\n"," Attempting uninstall: aiohttp\n"," Found existing installation: aiohttp 3.8.5\n"," Uninstalling aiohttp-3.8.5:\n"," Successfully uninstalled aiohttp-3.8.5\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n","gcsfs 2023.6.0 requires fsspec==2023.6.0, but you have fsspec 2022.5.0 which is incompatible.\n","google-colab 1.0.0 requires pandas==1.5.3, but you have pandas 1.4.3 which is incompatible.\n","google-colab 1.0.0 requires requests==2.27.1, but you have requests 2.28.1 which is incompatible.\n","torchaudio 2.0.2+cu118 requires torch==2.0.1, but you have torch 1.12.0 which is incompatible.\n","torchdata 0.6.1 requires torch==2.0.1, but you have torch 1.12.0 which is incompatible.\n","torchtext 0.15.2 requires torch==2.0.1, but you have torch 1.12.0 which is incompatible.\n","yfinance 0.2.25 requires pytz>=2022.5, but you have pytz 2022.1 which is incompatible.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed Pillow-9.2.0 PyYAML-6.0 aiohttp-3.8.1 aiosignal-1.2.0 attrs-21.4.0 certifi-2022.6.15 charset-normalizer-2.1.0 colorama-0.4.5 datasets-2.3.2 dill-0.3.5.1 einops-0.4.1 einops-exts-0.0.3 filelock-3.7.1 frozenlist-1.3.0 fsspec-2022.5.0 future-0.18.2 huggingface-hub-0.8.1 idna-3.3 minimagen-0.0.9 multidict-6.0.2 multiprocess-0.70.13 numpy-1.23.1 packaging-21.3 pandas-1.4.3 pyarrow-8.0.0 pyparsing-3.0.9 pytz-2022.1 regex-2022.7.9 requests-2.28.1 resize-right-0.0.2 responses-0.18.0 sentencepiece-0.1.96 tdqm-0.0.1 tokenizers-0.12.1 torch-1.12.0 torchvision-0.13.0 tqdm-4.64.0 transformers-4.20.1 typing-extensions-4.3.0 urllib3-1.26.10 xxhash-3.0.0 yarl-1.7.2\n"]},{"output_type":"display_data","data":{"application/vnd.colab-display-data+json":{"pip_warning":{"packages":["PIL","certifi","numpy","packaging","torch","tqdm"]}}},"metadata":{}}],"source":["#install the minimagen package\n","!pip install minimagen"]},{"cell_type":"code","source":["#utility imports\n","import os\n","from datetime import datetime\n","\n","#pytorch related imports\n","import torch.utils.data as data_utils\n","from torch import optim\n","\n","#minimagen related imports\n","from minimagen.Imagen import Imagen\n","from 
minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest\n","from minimagen.generate import load_minimagen, load_params\n","from minimagen.t5 import get_encoded_dim\n","from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \\\n"," create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \\\n"," load_testing_parameters\n","\n","# `import torch.utils.data as data_utils` binds only `data_utils`, not `torch`;\n","# later cells call torch.device / torch.arange / torch.utils.data.DataLoader directly.\n","import torch"],"metadata":{"id":"b_4eGJywmHR5"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["# Get device: Connect to GPU runtime for better performance\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","# Command line argument parser\n","parser = get_minimagen_parser()\n","class args_cls:\n"," a = 0\n","\n","#get an instance of the args_cls\n","args = args_cls()"],"metadata":{"id":"GoNwdipqmH95"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#directory creation for training\n","timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n","dir_path = f\"./training_{timestamp}\"\n","training_dir = create_directory(dir_path)"],"metadata":{"id":"eq8I0I7MmKFz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#A dictionary of hyperparameters\n","hyperparameters = dict(\n"," PARAMETERS=None,\n"," NUM_WORKERS=0,\n"," BATCH_SIZE=20,\n"," MAX_NUM_WORDS=32,\n"," IMG_SIDE_LEN=128,\n"," EPOCHS=10,\n"," T5_NAME='t5_small',\n"," TRAIN_VALID_FRAC=0.5,\n"," 
TRAINING_DIRECTORY = '/content/training_20230731_061334',\n"," TIMESTEPS=25,\n"," OPTIM_LR=0.0001,\n"," ACCUM_ITER=1,\n"," CHCKPT_NUM=500,\n"," VALID_NUM=None,\n"," RESTART_DIRECTORY=None,\n"," TESTING=False,\n"," timestamp=None,\n"," )\n","# Replace relevant values in arg dict\n","args.__dict__ = {**args.__dict__, **hyperparameters}"],"metadata":{"id":"8hdqSoAXoxqs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Data"],"metadata":{"id":"0CvubtDKEmHm"}},{"cell_type":"code","source":["# Load subset of Conceptual Captions dataset.\n","train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=False)\n","indices = torch.arange(1000)\n","\n","#create train and validation datasets with given number of samples\n","train_dataset = data_utils.Subset(train_dataset, indices)\n","valid_dataset = data_utils.Subset(valid_dataset, indices)\n","\n","# Create dataloaders\n","dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}\n","train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)\n","valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":105,"referenced_widgets":["f77bbc2602d846f8bb1c9e06f7b519ef","d2c87c66057f4e33bdbfe078ee47c0b2","fe33fa70d2c540e9831d408ffb5d3af9","61d7b89f706348548ebcaf0ced92a44e","a51b1054233342feaef0b16d2627a658","c4aa2aa530034af487957db71e9c509f","b4957dccdc5c4b83982212497febf4dc","592c8770c24c4366843336989c8cdb5f","f6c8ccfd74844b7ebafb37f181b51130","178935e2880043dc9eca862c09c05c16","aac7d4c2eebf4eb1914d2d98c52243ce"]},"id":"6LVG7NbZmNQq","executionInfo":{"status":"ok","timestamp":1690786766692,"user_tz":-330,"elapsed":9916,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"7b91d590-fb10-4c7a-e432-a315c656d54a"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["WARNING:datasets.builder:No config 
specified, defaulting to: conceptual_captions/unlabeled\n","WARNING:datasets.builder:Reusing dataset conceptual_captions (/root/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)\n"]},{"output_type":"display_data","data":{"text/plain":[" 0%| | 0/2 [00:00<?, ?it/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f77bbc2602d846f8bb1c9e06f7b519ef"}},"metadata":{}}]},{"cell_type":"markdown","source":["# UNet"],"metadata":{"id":"vmwLUZh2Eqhn"}},{"cell_type":"code","source":["# Instantiate Unet with default parameters and transfer to GPU if available\n","unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]\n","unets = [Unet(**unet_params).to(device) for unet_params in unets_params]"],"metadata":{"id":"IaKJG4IamPJD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Specify MinImagen parameters\n","imagen_params = dict(\n"," image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),\n"," timesteps=args.TIMESTEPS,\n"," cond_drop_prob=0.15,\n"," text_encoder_name=args.T5_NAME\n",")\n","\n","# Create MinImagen from UNets with specified imagen parameters\n","imagen = Imagen(unets=unets, **imagen_params).to(device)"],"metadata":{"id":"dl-w2Yy6mQ3Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Fill in unspecified arguments with defaults\n","unets_params = [{**get_default_args(Unet), **i} for i in unets_params]\n","imagen_params = {**get_default_args(Imagen), **imagen_params}\n","\n","# Get the size of the Imagen model in megabytes\n","model_size_MB = get_model_size(imagen)\n","\n","# Save all training info (config files, model size, etc.)\n","save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)"],"metadata":{"id":"tzkJfhuRmSqg"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 
Training"],"metadata":{"id":"0XcZq-Q8EthC"}},{"cell_type":"code","source":["# Create optimizer - Adam\n","optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)\n","\n","# Train the MinImagen instance\n","MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)"],"metadata":{"id":"I--5Lt18mUf8","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1690788612546,"user_tz":-330,"elapsed":1835395,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"2e19e526-6526-473f-8e77-007e3df8425d"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 1 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:05<00:22, 5.54s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:31<00:53, 17.84s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:45<00:31, 15.65s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:51<00:11, 11.95s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [01:06<00:00, 13.20s/it]\n","1it [01:14, 74.38s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.1316, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0483, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:27, 29.42s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 2 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 
[00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:11<00:46, 11.63s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:26<00:40, 13.34s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:34<00:21, 10.96s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:41<00:09, 9.59s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:56<00:00, 35.30s/it]\n","1it [03:10, 190.22s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0965, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0373, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:12, 50.50s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 3 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:07<00:31, 7.80s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:13<00:19, 6.65s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:27<00:20, 10.07s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:32<00:07, 7.78s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:54<00:00, 34.83s/it]\n","1it [03:54, 234.78s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0735, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0289, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:21, 52.38s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 4 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:15<01:02, 15.63s/it]\u001b[A\n"," 
40%|ββββ | 2/5 [00:23<00:32, 10.99s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:29<00:17, 8.55s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [02:48<01:00, 60.17s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:53<00:00, 34.62s/it]\n","1it [02:57, 177.30s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0502, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0210, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:07, 49.47s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 5 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:07<00:28, 7.14s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:31<00:52, 17.42s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:37<00:23, 11.92s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:43<00:09, 9.72s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:51<00:00, 10.29s/it]\n","1it [01:16, 76.25s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0274, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0135, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:32, 30.58s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 6 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:34<02:17, 34.45s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:45<01:01, 20.50s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:51<00:28, 
14.07s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:56<00:10, 10.56s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [01:11<00:00, 14.37s/it]\n","1it [01:17, 77.59s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0088, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0083, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:25, 29.04s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 7 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:05<00:21, 5.37s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:18<00:29, 9.75s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:34<00:25, 12.60s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:49<00:13, 13.81s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:53<00:00, 10.75s/it]\n","1it [01:00, 60.32s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9863, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0049, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:07, 25.52s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 8 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:10<00:41, 10.40s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:15<00:22, 7.41s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:30<00:21, 10.65s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:41<00:10, 
11.00s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:51<00:00, 10.24s/it]\n","1it [01:04, 64.48s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9715, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0007, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:05, 25.04s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 9 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:10<00:41, 10.36s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:16<00:23, 7.89s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:24<00:15, 7.73s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:35<00:09, 9.07s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:51<00:00, 34.28s/it]\n","1it [03:30, 210.85s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9587, device='cuda:0')\n","Unet 1 avg validation loss: tensor(0.9981, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:11, 50.39s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 10 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:23<01:32, 23.20s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:33<00:46, 15.36s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:37<00:20, 10.14s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:43<00:08, 8.75s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:50<00:00, 10.04s/it]\n","1it [01:37, 
97.50s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9483, device='cuda:0')\n","Unet 1 avg validation loss: tensor(0.9955, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:03, 24.66s/it]\n"]}]},{"cell_type":"markdown","source":["# Inference"],"metadata":{"id":"BEnz_4zPEwgu"}},
{"cell_type":"code","source":["from minimagen.generate import sample_and_save"],"metadata":{"id":"PUWUXmYDmdpm"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["# Specify the caption(s) to generate images for\n","captions = ['happy']"],"metadata":{"id":"PPzAqX0qmeKa"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["# Use `sample_and_save` to generate and save the images.\n","# `dir_path` is the directory created for the training run above; to sample from an\n","# earlier run instead, pass that run's directory (e.g. '/content/training_20230731_065902').\n","sample_and_save(captions, training_directory=dir_path)"],"metadata":{"id":"fMxM5zdNmf8e","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1690788695591,"user_tz":-330,"elapsed":3817,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"01a1d7c8-bfcc-420b-e679-2f57525f8d30"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["0it [00:00, ?it/s]\n","sampling loop time step: 0%| | 0/25 [00:00<?, ?it/s]\u001b[A\n","sampling loop time step: 12%|ββ | 3/25 [00:00<00:01, 21.00it/s]\u001b[A\n","sampling loop time step: 24%|βββ | 6/25 [00:00<00:00, 
20.08it/s]\u001b[A\n","sampling loop time step: 36%|ββββ | 9/25 [00:00<00:00, 20.17it/s]\u001b[A\n","sampling loop time step: 48%|βββββ | 12/25 [00:00<00:00, 20.03it/s]\u001b[A\n","sampling loop time step: 60%|ββββββ | 15/25 [00:00<00:00, 20.02it/s]\u001b[A\n","sampling loop time step: 72%|ββββββββ | 18/25 [00:00<00:00, 19.98it/s]\u001b[A\n","sampling loop time step: 80%|ββββββββ | 20/25 [00:01<00:00, 19.70it/s]\u001b[A\n","sampling loop time step: 88%|βββββββββ | 22/25 [00:01<00:00, 19.56it/s]\u001b[A\n","sampling loop time step: 100%|ββββββββββ| 25/25 [00:01<00:00, 19.65it/s]\n","1it [00:01, 1.28s/it]\n","sampling loop time step: 0%| | 0/25 [00:00<?, ?it/s]\u001b[A\n","sampling loop time step: 8%|β | 2/25 [00:00<00:01, 11.65it/s]\u001b[A\n","sampling loop time step: 16%|ββ | 4/25 [00:00<00:01, 11.37it/s]\u001b[A\n","sampling loop time step: 24%|βββ | 6/25 [00:00<00:01, 11.53it/s]\u001b[A\n","sampling loop time step: 32%|ββββ | 8/25 [00:00<00:01, 11.37it/s]\u001b[A\n","sampling loop time step: 40%|ββββ | 10/25 [00:00<00:01, 11.47it/s]\u001b[A\n","sampling loop time step: 48%|βββββ | 12/25 [00:01<00:01, 11.08it/s]\u001b[A\n","sampling loop time step: 56%|ββββββ | 14/25 [00:01<00:00, 11.27it/s]\u001b[A\n","sampling loop time step: 64%|βββββββ | 16/25 [00:01<00:00, 11.41it/s]\u001b[A\n","sampling loop time step: 72%|ββββββββ | 18/25 [00:01<00:00, 11.18it/s]\u001b[A\n","sampling loop time step: 80%|ββββββββ | 20/25 [00:01<00:00, 11.21it/s]\u001b[A\n","sampling loop time step: 88%|βββββββββ | 22/25 [00:01<00:00, 11.29it/s]\u001b[A\n","sampling loop time step: 100%|ββββββββββ| 25/25 [00:02<00:00, 11.27it/s]\n","2it [00:03, 1.76s/it]\n"]}]}]} |