"""Draw Data with Solara demo.

Draw a coloured 2-D point cloud with the ``drawdata`` ScatterWidget, then
inspect the decision boundary of a few scikit-learn classifiers on it, or
view / download the raw points as a table.
"""
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import solara
import solara.lab
from drawdata import ScatterWidget
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

# needed for solara up to version 1.28
plt.switch_backend("module://matplotlib_inline.backend_inline")

# title = "solara"

# Points drawn in the ScatterWidget; each entry becomes one DataFrame row
# (Page reads its "x", "y" and "color" columns).
drawdata: solara.Reactive[List[Dict]] = solara.reactive([])
# we keep the active tab in a reactive var so the state does not get lost when we change
# the orientation of the page (vertical or horizontal)
tab = solara.reactive(0)


def _response_method(y) -> str:
    """Return the DecisionBoundaryDisplay response method suited to labels *y*.

    Exactly two distinct labels -> ``"predict_proba"`` (a smooth probability
    surface); otherwise ``"predict"`` (hard class regions).
    NOTE(review): probability surfaces are limited to binary problems here
    because DecisionBoundaryDisplay has not always supported
    ``predict_proba`` for multiclass estimators -- confirm for the pinned
    scikit-learn version.
    """
    return "predict_proba" if len(np.unique(y)) == 2 else "predict"


@solara.component
def ClassifierDraw(classifier, X, y, response_method="predict_proba", figsize=(8, 8)):
    """Render the decision boundary of a fitted *classifier* plus its training points.

    Args:
        classifier: a fitted scikit-learn classifier; its class name is the plot title.
        X: feature array; column 0 is plotted as "x" and column 1 as "y".
        y: labels, also used as the scatter colours.
        response_method: forwarded to ``DecisionBoundaryDisplay.from_estimator``.
        figsize: matplotlib figure size in inches.
    """
    fig = plt.figure(figsize=figsize)
    # Draw on an explicit axes of *this* figure (without ``ax=`` the plot came out blank).
    ax = fig.add_subplot(111)
    disp = DecisionBoundaryDisplay.from_estimator(
        classifier,
        X,
        ax=ax,
        response_method=response_method,
        xlabel="x",
        ylabel="y",
        alpha=0.5,
    )
    disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
    # Title the axes we own rather than whatever pyplot considers "current".
    disp.ax_.set_title(type(classifier).__name__)
    # Drop pyplot's reference to this specific figure; solara renders the Figure object itself.
    plt.close(fig)
    solara.FigureMatplotlib(fig)


@solara.component
def DecisionTreeClassifierDraw(df):
    """Decision-tree controls (criterion, splitter) and the resulting decision boundary.

    Args:
        df: drawn points with "x", "y" and "color" columns.
    """
    criterion = solara.use_reactive("gini")
    splitter = solara.use_reactive("best")
    with solara.Row():
        solara.ToggleButtonsSingle(value=criterion, values=["gini", "entropy", "log_loss"])
        solara.ToggleButtonsSingle(value=splitter, values=["best", "random"])

    X = df[["x", "y"]].values
    y = df["color"]
    classifier = DecisionTreeClassifier(criterion=criterion.value, splitter=splitter.value).fit(X, y)
    ClassifierDraw(classifier, X, y, _response_method(y))


@solara.component
def LogisticRegressionDraw(df):
    """Logistic-regression controls (penalty, solver, l1_ratio) and the decision boundary.

    Invalid penalty/solver combinations are shown as an error message instead
    of a plot.

    Args:
        df: drawn points with "x", "y" and "color" columns.
    """
    penalty = solara.use_reactive("l2")
    solver = solara.use_reactive("lbfgs")
    l1_ratio = solara.use_reactive(0.5)
    with solara.Row():
        solara.ToggleButtonsSingle(value=penalty, values=["l1", "l2", "elasticnet", "none"])
        solara.ToggleButtonsSingle(value=solver, values=["newton-cg", "lbfgs", "liblinear", "sag", "saga"])
    is_elasticnet = penalty.value == "elasticnet"
    if is_elasticnet:
        solara.FloatSlider("l1_ratio", value=l1_ratio, min=0, max=1, step=0.1)

    X = df[["x", "y"]].values
    y = df["color"]
    try:
        classifier = LogisticRegression(
            # "no penalty" is spelled None since scikit-learn 1.2; the string "none"
            # was deprecated then and removed in 1.4 (requires scikit-learn >= 1.2).
            penalty=None if penalty.value == "none" else penalty.value,
            solver=solver.value,
            # l1_ratio only applies to elasticnet; passing it otherwise makes sklearn warn.
            l1_ratio=l1_ratio.value if is_elasticnet else None,
        ).fit(X, y)
    except ValueError as e:
        # e.g. a solver that does not support the chosen penalty
        solara.Error(str(e))
    else:
        ClassifierDraw(classifier, X, y, _response_method(y))


@solara.component
def Page():
    """Top-level page: drawing canvas plus classifier / table tabs, with a layout toggle."""
    vertical = solara.use_reactive(True)
    solara.AppBarTitle("Draw Data with Solara demo")
    df = pd.DataFrame(drawdata.value) if drawdata.value else None

    with solara.AppBar():
        # TODO: doesn't work, ScatterWidget does not update when data is updated (read only?)
        # solara.Button(icon_name="mdi-delete", on_click=lambda: drawdata.set([]), icon=True)
        solara.lab.ThemeToggle(enable_auto=False)
        # demo how solara can dynamically change the layout
        solara.Button(
            icon_name="mdi-align-vertical-top" if vertical.value else "mdi-align-horizontal-left",
            on_click=lambda: vertical.set(not vertical.value),
            icon=True,
        )
    dark_background = solara.lab.use_dark_effective()
    # Keep matplotlib figures in sync with the current light/dark theme.
    plt.style.use("dark_background" if dark_background else "default")

    with solara.Column() if vertical.value else solara.Row():
        # with solara, we don't just create the widget, but an element that describes it
        # and instead of observe, we have on_ callbacks
        # Note: if we store the data in the reactive var (drawdata), we keep the drawing
        # on hot reload.
        ScatterWidget.element(data=drawdata.value, on_data=drawdata.set)
        # downside of using elements and components: we cannot call methods on the widget,
        # so we need to re-create the dataframe ourselves
        with solara.lab.Tabs(value=tab):
            with solara.lab.Tab("classifier"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    if df is not None and df["color"].nunique() > 1:
                        with solara.Column(style={"max-height": "500px", "padding-top": "0px"}):
                            with solara.lab.Tabs():
                                with solara.lab.Tab("DecisionTreeClassifier"):
                                    DecisionTreeClassifierDraw(df)
                                with solara.lab.Tab("LogisticRegression"):
                                    LogisticRegressionDraw(df)
                    else:
                        with solara.Column(style={"justify-content": "center"}) if not vertical.value else solara.Row():
                            solara.Info("Choose at least two colors to draw a decision boundary.")
            with solara.lab.Tab("table view"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    if df is not None:
                        with solara.FileDownload(data=lambda: df.to_csv(), filename="drawdata.csv"):
                            solara.Button("download as csv", icon_name="mdi-download", outlined=True, color="primary")
                        solara.DataFrame(df)


# in the notebook: Page()