# Hugging Face Spaces page status at capture time: Runtime error
# Scikit learn example https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html | |
import gradio as gr | |
from sklearn.cluster import OPTICS, cluster_optics_dbscan | |
import matplotlib.gridspec as gridspec | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Use matplotlib's non-interactive Agg backend: figures are only rendered to
# images for the Gradio Plot component, never shown in a GUI window.
plt.switch_backend("agg")

# UI theme, adapted from
# https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
_THEME_FONT_STACK = [
    gr.themes.GoogleFont("Open Sans"),  # preferred web font
    "ui-sans-serif",  # CSS fallbacks, tried in order
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    font=_THEME_FONT_STACK,
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
)
def _plot_clusters(ax, points, labels, colors, title):
    """Scatter 2-D ``points`` on ``ax``, coloured by cluster label.

    Cluster ``k`` is drawn with ``colors[k % len(colors)]`` so that clusters
    beyond the palette are still shown (with a recycled colour) rather than
    silently dropped. Noise points (label -1) are drawn as faint black crosses.
    """
    n_clusters = int(labels.max()) + 1 if labels.size else 0
    for klass in range(n_clusters):
        members = points[labels == klass]
        ax.plot(members[:, 0], members[:, 1], colors[klass % len(colors)], alpha=0.3)
    noise = points[labels == -1]
    ax.plot(noise[:, 0], noise[:, 1], "k+", alpha=0.1)
    ax.set_title(title)


def do_submit(n_points_per_cluster, min_samples, xi, min_cluster_size):
    """Cluster synthetic blobs with OPTICS and plot it against two DBSCAN cuts.

    Six 2-D Gaussian blobs with different spreads are sampled from a fixed
    seed, so identical slider values always produce the identical figure.

    Args:
        n_points_per_cluster: Points drawn per blob (coerced to ``int``).
        min_samples: OPTICS ``min_samples`` (coerced to ``int``).
        xi: OPTICS ``xi`` steepness threshold (coerced to ``float``).
        min_cluster_size: OPTICS ``min_cluster_size`` (coerced to ``float``).

    Returns:
        matplotlib.figure.Figure: reachability plot on the top row; below it
        the OPTICS Xi clustering and DBSCAN-style cuts at eps=0.5 and eps=2.0.
    """
    n_points_per_cluster = int(n_points_per_cluster)

    # Generate sample data with a private RNG rather than np.random.seed(0):
    # reseeding the global RNG is shared state, and concurrent handler calls
    # could interleave their draws. RandomState(0).randn yields the same
    # sequence as the legacy seed(0) + np.random.randn calls.
    rng = np.random.RandomState(0)
    blobs = [  # (center, standard deviation), drawn in this order
        ([-5, -2], 0.8),
        ([4, -1], 0.1),
        ([1, -2], 0.2),
        ([-2, 3], 0.3),
        ([3, -2], 1.6),
        ([5, 6], 2),
    ]
    X = np.vstack(
        [center + std * rng.randn(n_points_per_cluster, 2) for center, std in blobs]
    )

    clust = OPTICS(
        min_samples=int(min_samples),
        xi=float(xi),
        min_cluster_size=float(min_cluster_size),
    )
    clust.fit(X)

    # Flat DBSCAN-style labelings extracted from the same OPTICS ordering.
    labels_050 = cluster_optics_dbscan(
        reachability=clust.reachability_,
        core_distances=clust.core_distances_,
        ordering=clust.ordering_,
        eps=0.5,
    )
    labels_200 = cluster_optics_dbscan(
        reachability=clust.reachability_,
        core_distances=clust.core_distances_,
        ordering=clust.ordering_,
        eps=2,
    )

    # Reachability and labels re-indexed into OPTICS processing order.
    space = np.arange(len(X))
    reachability = clust.reachability_[clust.ordering_]
    labels = clust.labels_[clust.ordering_]

    # Build on an explicit figure instead of pyplot's implicit "current
    # figure", so the result does not depend on global pyplot state.
    fig = plt.figure(figsize=(10, 6))
    grid = gridspec.GridSpec(2, 3)
    ax1 = fig.add_subplot(grid[0, :])
    ax2 = fig.add_subplot(grid[1, 0])
    ax3 = fig.add_subplot(grid[1, 1])
    ax4 = fig.add_subplot(grid[1, 2])

    # Reachability plot
    reach_colors = ["g.", "r.", "b.", "y.", "c."]
    n_clusters = int(labels.max()) + 1 if labels.size else 0
    for klass in range(n_clusters):
        in_cluster = labels == klass
        ax1.plot(
            space[in_cluster],
            reachability[in_cluster],
            reach_colors[klass % len(reach_colors)],
            alpha=0.3,
        )
    ax1.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
    # Horizontal guides at the two eps values used for the DBSCAN cuts.
    ax1.plot(space, np.full_like(space, 2.0, dtype=float), "k-", alpha=0.5)
    ax1.plot(space, np.full_like(space, 0.5, dtype=float), "k-.", alpha=0.5)
    ax1.set_ylabel("Reachability (epsilon distance)")
    ax1.set_title("Reachability Plot")

    _plot_clusters(
        ax2, X, clust.labels_, ["g.", "r.", "b.", "y.", "c."],
        "Automatic Clustering\nOPTICS",
    )
    _plot_clusters(
        ax3, X, labels_050, ["g.", "r.", "b.", "c."],
        "Clustering at 0.5 epsilon cut\nDBSCAN",
    )
    _plot_clusters(
        ax4, X, labels_200, ["g.", "m.", "y.", "c."],
        "Clustering at 2.0 epsilon cut\nDBSCAN",
    )

    fig.tight_layout()
    # Remove the figure from pyplot's registry: otherwise every call leaks a
    # figure (matplotlib warns after 20). The Figure object itself can still
    # be rendered by the caller via fig.savefig.
    plt.close(fig)
    return fig
title = "Demo of OPTICS clustering algorithm"

with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html)"
    )
    # Implicit concatenation reproduces the original backslash-continued text.
    gr.Markdown(
        "Finds core samples of high density and expands clusters from them. This example uses data that is "
        "generated so that the clusters have different densities. The [OPTICS](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html#sklearn.cluster.OPTICS) is first used with its Xi cluster detection "
        "method, and then setting specific thresholds on the reachability, which corresponds to [DBSCAN](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html#sklearn.cluster.DBSCAN). We can see that "
        "the different clusters of OPTICS’s Xi method can be recovered with different choices of thresholds in DBSCAN."
    )

    # `equal_height` is a constructor argument: `Row().style(...)` was
    # removed in Gradio 4. Column scales must be integers; 3:4 keeps the
    # original 0.75 : 1 controls-to-plot width ratio.
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            n_points_per_cluster = gr.Slider(
                minimum=200,
                maximum=500,
                label="Number of points per cluster",
                step=50,
                value=250,
            )
            min_samples = gr.Slider(
                minimum=10,
                maximum=100,
                label="OPTICS - Minimum number of samples",
                step=5,
                value=50,
                info="The number of samples in a neighborhood for a point to be considered as a core point.",
            )
            xi = gr.Slider(
                minimum=0,
                maximum=0.2,
                label="OPTICS - Xi",
                step=0.01,
                value=0.05,
                info="Determines the minimum steepness on the reachability plot that constitutes a cluster boundary. ",
            )
            min_cluster_size = gr.Slider(
                minimum=0.01,
                maximum=0.1,
                label="OPTICS - Minimum cluster size",
                step=0.01,
                value=0.05,
                info="Minimum number of samples in an OPTICS cluster, expressed as an absolute number or a fraction of the number of samples (rounded to be at least 2).",
            )
        with gr.Column(scale=4):
            plt_out = gr.Plot()

    # Re-render whenever any slider changes; every handler reads all four.
    slider_inputs = [n_points_per_cluster, min_samples, xi, min_cluster_size]
    for slider in slider_inputs:
        slider.change(do_submit, inputs=slider_inputs, outputs=plt_out)
    # Draw the default configuration on page load instead of an empty plot.
    demo.load(do_submit, inputs=slider_inputs, outputs=plt_out)

if __name__ == "__main__":
    demo.launch()