Commit
·
2793e0d
1
Parent(s):
a3e02cb
Finalizing notes and type hints
Browse files- Data_Plotting/Plot_TSNE.py +11 -10
- app.py +1 -3
Data_Plotting/Plot_TSNE.py
CHANGED
@@ -6,15 +6,10 @@ import numpy as np
|
|
6 |
# Latent Feature Cluster for Training Data using T-SNE
|
7 |
def TSNE_reduction(latent_points: np.ndarray, perplexity=30, learning_rate=20):
|
8 |
"""
|
9 |
-
:param latent_points: [ndarray] - an array of arrays that define the points of
|
10 |
-
:param perplexity: [int] - default perplexity = 30 " Perplexity balances the attention t-SNE gives to local and
|
11 |
-
|
12 |
-
|
13 |
-
:param learning_rate: [int] - default learning rate = 200 "If the learning rate is too high, the data may look
|
14 |
-
like a ‘ball’ with any point approximately equidistant from its nearest neighbours.
|
15 |
-
If the learning rate is too low, most points may look compressed in a dense cloud with few outliers."
|
16 |
-
Recommended: learning_rate(10-1000)
|
17 |
-
:return: [tuple] - the output is the x and y coordinates for the reduced latent space, a title, and an embedding
|
18 |
"""
|
19 |
model = TSNE(n_components=2, random_state=0, perplexity=perplexity,
|
20 |
learning_rate=learning_rate)
|
@@ -31,8 +26,14 @@ def TSNE_reduction(latent_points: np.ndarray, perplexity=30, learning_rate=20):
|
|
31 |
|
32 |
|
33 |
def plot_dimensionality_reduction(x: list, y: list, label_set: list, title: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
plt.title(title)
|
35 |
-
# Color points based on
|
36 |
if label_set[0].dtype == float:
|
37 |
plt.scatter(x, y, c=label_set)
|
38 |
cbar = plt.colorbar()
|
|
|
6 |
# Latent Feature Cluster for Training Data using T-SNE
|
7 |
def TSNE_reduction(latent_points: np.ndarray, perplexity=30, learning_rate=20):
|
8 |
"""
|
9 |
+
:param latent_points: [ndarray] - an array of arrays that define the points of multiple objects in the latent space
|
10 |
+
:param perplexity: [int] - default perplexity = 30 "Perplexity balances the attention t-SNE gives to local and global aspects of the data. It is roughly a guess of the number of close neighbors each point has... a denser dataset ... requires higher perplexity value." Recommended: perplexity(5-50)
|
11 |
+
:param learning_rate: [int] - default learning rate = 20 "If the learning rate is too high, the data may look like a ‘ball’ with any point approximately equidistant from its nearest neighbours. If the learning rate is too low, most points may look compressed in a dense cloud with few outliers." Recommended: learning_rate(10-1000)
|
12 |
+
:return: [tuple] - the output is the x and y coordinates for the reduced latent space, a title, and a TSNE embedding
|
|
|
|
|
|
|
|
|
|
|
13 |
"""
|
14 |
model = TSNE(n_components=2, random_state=0, perplexity=perplexity,
|
15 |
learning_rate=learning_rate)
|
|
|
26 |
|
27 |
|
28 |
def plot_dimensionality_reduction(x: list, y: list, label_set: list, title: str):
|
29 |
+
"""
|
30 |
+
:param x: [list] - the first set of coordinates for each latent point
|
31 |
+
:param y: [list] - the second set of coordinates for each latent point
|
32 |
+
:param label_set: [list] - a set of values that define the color of each point based on an additional quantitative attribute.
|
33 |
+
:return: matplotlib figure - the output is a matplotlib figure that displays all the points in a 2-dimensional latent space, based on the labels provided.
|
34 |
+
"""
|
35 |
plt.title(title)
|
36 |
+
# Color points based on a continuous label
|
37 |
if label_set[0].dtype == float:
|
38 |
plt.scatter(x, y, c=label_set)
|
39 |
cbar = plt.colorbar()
|
app.py
CHANGED
@@ -89,9 +89,7 @@ if st.button('Generate Dataset'): # Generate the dataset
|
|
89 |
plt.figure(3)
|
90 |
# set the color values for the plot
|
91 |
plot_dimensionality_reduction(x, y, avg_density, title)
|
92 |
-
|
93 |
-
# plt.scatter(x, y)
|
94 |
-
plt.title(title)
|
95 |
plt.figure(3)
|
96 |
st.pyplot(plt.figure(3))
|
97 |
|
|
|
89 |
plt.figure(3)
|
90 |
# set the color values for the plot
|
91 |
plot_dimensionality_reduction(x, y, avg_density, title)
|
92 |
+
# plt.title(title)
|
|
|
|
|
93 |
plt.figure(3)
|
94 |
st.pyplot(plt.figure(3))
|
95 |
|