{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "zlD1ysQ7Tckp" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from pathlib import Path" ] }, { "cell_type": "code", "source": [ "!pip3 install gradio" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YJQBJHP1Tym5", "outputId": "46afdd2e-b5fc-4eef-93e3-1f3bc9b7f1ce" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting gradio\n", " Downloading gradio-3.38.0-py3-none-any.whl (19.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.8/19.8 MB\u001b[0m \u001b[31m72.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)\n", " Downloading aiofiles-23.1.0-py3-none-any.whl (14 kB)\n", "Requirement already satisfied: aiohttp~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.8.4)\n", "Requirement already satisfied: altair<6.0,>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.2.2)\n", "Collecting fastapi (from gradio)\n", " Downloading fastapi-0.100.0-py3-none-any.whl (65 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.7/65.7 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting ffmpy (from gradio)\n", " Downloading ffmpy-0.3.1.tar.gz (5.5 kB)\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Collecting gradio-client>=0.2.10 (from gradio)\n", " Downloading gradio_client-0.2.10-py3-none-any.whl (288 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m289.0/289.0 kB\u001b[0m \u001b[31m31.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting httpx (from gradio)\n", " Downloading httpx-0.24.1-py3-none-any.whl (75 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.4/75.4 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting huggingface-hub>=0.14.0 (from gradio)\n", " Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m28.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.1.2)\n", "Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.0.0)\n", "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.3)\n", "Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.7.1)\n", "Collecting mdit-py-plugins<=0.3.3 (from gradio)\n", " Downloading mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.5/50.5 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy~=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.22.4)\n", "Collecting orjson~=3.0 (from gradio)\n", " Downloading orjson-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (138 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.7/138.7 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from gradio) (23.1)\n", "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.5.3)\n", "Requirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (8.4.0)\n", "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.10.11)\n", "Collecting pydub (from gradio)\n", " Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n", "Collecting python-multipart (from gradio)\n", " Downloading python_multipart-0.0.6-py3-none-any.whl (45 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.7/45.7 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n", "Requirement already satisfied: requests~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.27.1)\n", "Collecting semantic-version~=2.0 (from gradio)\n", " Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n", "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.7.1)\n", "Collecting uvicorn>=0.14.0 (from gradio)\n", " Downloading uvicorn-0.23.1-py3-none-any.whl (59 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.5/59.5 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting websockets<12.0,>=10.0 (from gradio)\n", " Downloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.9/129.9 kB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (23.1.0)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (2.0.12)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (6.0.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (4.0.2)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (1.9.2)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp~=3.0->gradio) (1.3.1)\n", "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.4)\n", "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (4.3.3)\n", "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.12.0)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from gradio-client>=0.2.10->gradio) (2023.6.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.14.0->gradio) (3.12.2)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.14.0->gradio) (4.65.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2)\n", "Collecting linkify-it-py<3,>=1 (from markdown-it-py[linkify]>=2.0.0->gradio)\n", " Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.1.0)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (0.11.0)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (4.41.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.4.4)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (3.1.0)\n", "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (2.8.2)\n", "INFO: pip is looking at multiple versions of mdit-py-plugins to determine which version is compatible with other requirements. This could take a while.\n", "Collecting mdit-py-plugins<=0.3.3 (from gradio)\n", " Downloading mdit_py_plugins-0.3.2-py3-none-any.whl (50 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.4/50.4 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.3.1-py3-none-any.whl (46 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.5/46.5 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.3.0-py3-none-any.whl (43 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.7/43.7 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.2.8-py3-none-any.whl (41 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.2.7-py3-none-any.whl (41 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.2.6-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.5-py3-none-any.whl (39 kB)\n", "INFO: pip is looking at multiple versions of mdit-py-plugins to determine which version is compatible with other requirements. This could take a while.\n", " Downloading mdit_py_plugins-0.2.4-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.3-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.2-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.1-py3-none-any.whl (38 kB)\n", " Downloading mdit_py_plugins-0.2.0-py3-none-any.whl (38 kB)\n", "INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.\n", " Downloading mdit_py_plugins-0.1.0-py3-none-any.whl (37 kB)\n", "Collecting markdown-it-py[linkify]>=2.0.0 (from gradio)\n", " Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.5/87.5 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading markdown_it_py-2.2.0-py3-none-any.whl (84 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2022.7.1)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (2023.5.7)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (3.4)\n", "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio) (8.1.6)\n", "Collecting h11>=0.8 (from uvicorn>=0.14.0->gradio)\n", " Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting starlette<0.28.0,>=0.27.0 (from fastapi->gradio)\n", " Downloading starlette-0.27.0-py3-none-any.whl (66 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.0/67.0 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting httpcore<0.18.0,>=0.15.0 (from httpx->gradio)\n", " Downloading httpcore-0.17.3-py3-none-any.whl (74 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.5/74.5 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (1.3.0)\n", "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx->gradio) (3.7.1)\n", "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.19.3)\n", "Collecting uc-micro-py (from linkify-it-py<3,>=1->markdown-it-py[linkify]>=2.0.0->gradio)\n", " Downloading uc_micro_py-1.0.2-py3-none-any.whl (6.2 kB)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n", "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx->gradio) (1.1.2)\n", "Building wheels for collected packages: ffmpy\n", " Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for ffmpy: filename=ffmpy-0.3.1-py3-none-any.whl size=5579 sha256=1f000f2f7853794c5857deb6ffb3305dbe3ceedf8e546a42ffd49709d11a0496\n", " Stored in directory: /root/.cache/pip/wheels/01/a6/d1/1c0828c304a4283b2c1639a09ad86f83d7c487ef34c6b4a1bf\n", "Successfully built ffmpy\n", "Installing collected packages: pydub, ffmpy, websockets, uc-micro-py, semantic-version, python-multipart, orjson, markdown-it-py, h11, aiofiles, uvicorn, starlette, mdit-py-plugins, linkify-it-py, huggingface-hub, httpcore, httpx, fastapi, gradio-client, gradio\n", " Attempting uninstall: markdown-it-py\n", " Found existing installation: markdown-it-py 3.0.0\n", " Uninstalling markdown-it-py-3.0.0:\n", " Successfully uninstalled markdown-it-py-3.0.0\n", "Successfully installed aiofiles-23.1.0 fastapi-0.100.0 ffmpy-0.3.1 gradio-3.38.0 gradio-client-0.2.10 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 huggingface-hub-0.16.4 linkify-it-py-2.0.2 markdown-it-py-2.2.0 mdit-py-plugins-0.3.3 orjson-3.9.2 pydub-0.25.1 python-multipart-0.0.6 semantic-version-2.10.0 starlette-0.27.0 uc-micro-py-1.0.2 uvicorn-0.23.1 websockets-11.0.3\n" ] } ] }, { "cell_type": "code", "source": [ "LABELS=Path('class_names.txt').read_text().splitlines()" ], "metadata": { "id": "PGdL4zXMT8Bn" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "model = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding='same'),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(32, 64, 3, padding='same'),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(64, 128, 3, padding='same'),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Flatten(),\n", " nn.Linear(1152, 256),\n", " nn.ReLU(),\n", " nn.Linear(256, len(LABELS)),\n", ")" ], "metadata": { "id": "XZ1AM4BxUMnH" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "state_dict=torch.load('pytorch_model.bin',map_location='cpu')" ], "metadata": { "id": "lDJZ7Vo9Ui0Y" }, "execution_count": 9, "outputs": [] }, { "cell_type": "code", "source": [ "import gradio as gr" ], "metadata": { "id": "LzwKoPfTUq5x" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "model.load_state_dict(state_dict, strict=False)\n", "model.eval()\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e_6paL_qU4YB", "outputId": "e018df13-6ac1-47a0-d634-cb8084d05d5a" }, "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Sequential(\n", " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", " (1): ReLU()\n", " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", " (4): ReLU()\n", " (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", " (7): ReLU()\n", " (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (9): Flatten(start_dim=1, end_dim=-1)\n", " (10): Linear(in_features=1152, out_features=256, bias=True)\n", " (11): ReLU()\n", " (12): Linear(in_features=256, out_features=100, bias=True)\n", ")" ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "def predict(im):\n", " x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n", "\n", " with torch.no_grad():\n", " out = model(x)\n", "\n", " probabilities = torch.nn.functional.softmax(out[0], dim=0)\n", "\n", " values, indices = torch.topk(probabilities, 5)\n", "\n", " return {LABELS[i]: v.item() for i, v in zip(indices, values)}" ], "metadata": { "id": "baptDAiGVCvF" }, "execution_count": 11, "outputs": [] }, { "cell_type": "code", "source": [ "interface=gr.Interface(\n", " predict,\n", " inputs=\"sketchpad\",\n", " outputs='label',\n", " theme=\"huggingface\",\n", " title=\"Sketch Recognition\",\n", " description=\"Sketch something for the model to guess in realtime\",\n", " article = \"

Sketch Recognition | Demo Model

\",\n", " live=True)\n", "interface.launch(share=True)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "4sZiRCv-VI9O", "outputId": "476b68b7-397e-4465-fb83-fb9b6dff499d" }, "execution_count": 13, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py:694: UserWarning: Cannot load huggingface. Caught Exception: The space huggingface does not exist\n", " warnings.warn(f\"Cannot load {theme}. Caught Exception: {str(e)}\")\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n", "Running on public URL: https://033d7a67112f637955.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "Traceback (most recent call last):\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/routes.py\", line 442, in run_predict\n", " output = await app.get_blocks().process_api(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1389, in process_api\n", " result = await self.call_function(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1094, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n", " return await future\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n", " result = context.run(func, *args)\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 703, in wrapper\n", " response = f(*args, **kwargs)\n", " File \"\", line 2, in predict\n", " x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n", "TypeError: must be real number, not NoneType\n", "Traceback (most recent call last):\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/routes.py\", line 442, in run_predict\n", " output = await app.get_blocks().process_api(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1389, in process_api\n", " result = await self.call_function(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1094, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n", " return await future\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n", " result = context.run(func, *args)\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 703, in wrapper\n", " response = f(*args, **kwargs)\n", " File \"\", line 2, in predict\n", " x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n", "TypeError: must be real number, not NoneType\n", "Traceback (most recent call last):\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/routes.py\", line 442, in run_predict\n", " output = await app.get_blocks().process_api(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1389, in process_api\n", " result = await self.call_function(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1094, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n", " return await future\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n", " result = context.run(func, *args)\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 703, in wrapper\n", " response = f(*args, **kwargs)\n", " File \"\", line 2, in predict\n", " x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n", "TypeError: must be real number, not NoneType\n", "Traceback (most recent call last):\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/routes.py\", line 442, in run_predict\n", " output = await app.get_blocks().process_api(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1389, in process_api\n", " result = await self.call_function(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1094, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n", " return await future\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n", " result = context.run(func, *args)\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 703, in wrapper\n", " response = f(*args, **kwargs)\n", " File \"\", line 2, in predict\n", " x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n", "TypeError: must be real number, not NoneType\n", "Traceback (most recent call last):\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/routes.py\", line 442, in run_predict\n", " output = await app.get_blocks().process_api(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1389, in process_api\n", " result = await self.call_function(\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1094, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n", " return await future\n", " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n", " result = context.run(func, *args)\n", " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 703, in wrapper\n", " response = f(*args, **kwargs)\n", " File \"\", line 2, in predict\n", " x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n", "TypeError: must be real number, not NoneType\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Keyboard interruption in main thread... closing server.\n", "Killing tunnel 127.0.0.1:7860 <> https://033d7a67112f637955.gradio.live\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [] }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "30uwS_0zVhT3" }, "execution_count": null, "outputs": [] } ] }