{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eea9f1b8-4240-4a68-a412-2b4071b2c04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from drawdata import ScatterWidget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "207dbcfd-e731-4758-8035-d1f429aa10d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "widget = ScatterWidget()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b77030f4-c895-4a39-96ce-b79d6a8a6d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display\n",
    "from sklearn.inspection import DecisionBoundaryDisplay\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b9d36c79-3d1d-4084-a1a6-43f3b95c06fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b21ddf1e80a4a34a786547610d52aa0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(ScatterWidget(), Output()))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The widget is (re)created in the same cell that registers the observer, so\n",
    "# re-running this cell never stacks several callbacks on one widget.\n",
    "widget = ScatterWidget()\n",
    "output = ipywidgets.Output()\n",
    "\n",
    "\n",
    "@output.capture(clear_output=True)\n",
    "def on_change(change):\n",
    "    \"\"\"Refit a decision tree on the drawn points and redraw its decision boundary.\n",
    "\n",
    "    Registered below as an observer of the widget's ``data`` trait and also\n",
    "    called once directly with ``None``; ``change`` itself is never read.\n",
    "    Nothing is drawn until the widget holds points in at least two colors.\n",
    "    Everything displayed here is captured into ``output``, which is cleared\n",
    "    at the start of every call.\n",
    "    \"\"\"\n",
    "    df = widget.data_as_pandas\n",
    "    if len(df) == 0:\n",
    "        return\n",
    "\n",
    "    n_classes = df[\"color\"].nunique()\n",
    "    if n_classes < 2:\n",
    "        # A classifier needs at least two classes to have a boundary to show.\n",
    "        return\n",
    "\n",
    "    X = df[[\"x\", \"y\"]].values\n",
    "    y = df[\"color\"]\n",
    "\n",
    "    # Vertical spacing above the figure.\n",
    "    display(HTML(\"<br><br><br>\"))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 12))\n",
    "    classifier = DecisionTreeClassifier().fit(X, y)\n",
    "    disp = DecisionBoundaryDisplay.from_estimator(\n",
    "        classifier,\n",
    "        X,\n",
    "        ax=ax,\n",
    "        # Probabilities are plotted only in the two-class case; with more\n",
    "        # classes the hard predictions are plotted instead.\n",
    "        response_method=\"predict_proba\" if n_classes == 2 else \"predict\",\n",
    "        xlabel=\"x\",\n",
    "        ylabel=\"y\",\n",
    "        alpha=0.5,\n",
    "    )\n",
    "    disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor=\"k\")\n",
    "    disp.ax_.set_title(type(classifier).__name__)\n",
    "    plt.show()\n",
    "    # This callback runs on every edit; release the figure explicitly so\n",
    "    # figures do not pile up on backends that keep them open after show().\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "widget.observe(on_change, names=[\"data\"])\n",
    "on_change(None)\n",
    "page = ipywidgets.HBox([widget, output])\n",
    "page"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f9bb7ce-8a7f-4879-9ddf-5b256bb0ff64",
   "metadata": {},
   "source": [
    "\n",
    "p<br><br><br><br><br><br><br><br><br><br><br><br>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}