import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs


def get_clusters_plot(n_blobs: int, quantile: float, cluster_std: float):
    """Cluster freshly generated 2D blobs with Mean Shift and plot the result.

    Args:
        n_blobs: Number of blob centers passed to ``make_blobs``.
        quantile: Quantile passed to ``estimate_bandwidth``; the resulting
            bandwidth is then used by ``MeanShift``.
        cluster_std: Standard deviation of each generated blob.

    Returns:
        A ``(figure, message)`` tuple: the matplotlib figure showing the
        clustered points with their estimated centers, and a text message
        stating whether the estimated cluster count matches ``n_blobs``.
    """
    X, _, centers = make_blobs(
        n_samples=10000,
        cluster_std=cluster_std,
        centers=n_blobs,
        return_centers=True,
    )

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_

    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)

    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # so overlapping callbacks cannot draw onto each other's figures.
    fig, ax = plt.subplots()
    for k in range(n_clusters_):
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )

    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.set_title(f"Estimated number of clusters: {n_clusters_}")

    # pyplot keeps every figure it creates alive until it is closed. Without
    # this, each slider change would leak one figure for the lifetime of the
    # server. The Figure object itself stays usable for rendering.
    plt.close(fig)

    # NOTE(review): the message text is reproduced as visible in the source;
    # if it was originally wrapped in HTML markup, restore that markup here.
    if len(centers) != n_clusters_:
        message = (
            "\n"
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.\n"
        )
    else:
        message = (
            ""
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!\n"
        )
    return fig, message


with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Mean Shift Clustering

    This space shows how to use the [Mean Shift Clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) algorithm to cluster 2D data points.
    You can change the parameters using the sliders and see how the model performs.

    This space is based on [sklearn's original demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py)
    """
    )
    with gr.Row():
        with gr.Column(scale=1):
            n_blobs = gr.Slider(
                minimum=2,
                maximum=10,
                label="Number of clusters in the data",
                step=1,
                value=3,
            )
            quantile = gr.Slider(
                # A quantile of 0 makes estimate_bandwidth return 0 (each
                # point's nearest neighbour is itself), which MeanShift
                # rejects; start one step above zero instead.
                minimum=0.05,
                maximum=1,
                step=0.05,
                value=0.2,
                label="Quantile",
                info="Used to determine clustering's bandwidth.",
            )
            cluster_std = gr.Slider(
                minimum=0.1,
                maximum=1,
                label="Clusters' standard deviation",
                step=0.1,
                value=0.6,
            )
        with gr.Column(scale=4):
            clusters_plots = gr.Plot(label="Clusters' Plot")
            message = gr.HTML()

    inputs = [n_blobs, quantile, cluster_std]
    outputs = [clusters_plots, message]

    # Every slider re-runs the same computation with the full set of inputs.
    for slider in (n_blobs, quantile, cluster_std):
        slider.change(get_clusters_plot, inputs, outputs, queue=False)

    # Render an initial plot with the default slider values on page load.
    demo.load(get_clusters_plot, inputs, outputs, queue=False)


if __name__ == "__main__":
    demo.launch()