Spaces:
Runtime error
Runtime error
import streamlit as st | |
import nltk | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
from scipy.spatial.distance import cosine | |
import numpy as np | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from sklearn.cluster import KMeans | |
import tensorflow as tf | |
import tensorflow_hub as hub | |
def cluster_examples(messages, embed, nc=3):
    """Cluster sentence embeddings with k-means and render each cluster.

    Parameters
    ----------
    messages : sequence of str
        The sentences, in the same order as the rows of ``embed``.
    embed : array-like
        One embedding vector per message; passed straight to
        ``KMeans.fit_predict``.
    nc : int, default 3
        Requested number of clusters. It is capped at ``len(embed)`` because
        KMeans raises ``ValueError`` when ``n_clusters`` exceeds the number of
        samples (the UI slider allows up to 15).

    Writes, via ``st.markdown``, a ``"CLUSTER : <n>"`` heading per cluster
    followed by the messages assigned to it. Returns None.
    """
    n_samples = len(embed)
    if n_samples == 0:
        # Nothing to cluster; KMeans would reject an empty input anyway.
        return
    nc = min(nc, n_samples)

    km = KMeans(
        n_clusters=nc, init='random',
        n_init=10, max_iter=300,
        tol=1e-04, random_state=0,
    )
    labels = km.fit_predict(embed)

    for n in range(nc):
        st.markdown("CLUSTER : %d" % n)
        # Preserve input order of the messages inside each cluster.
        for message, label in zip(messages, labels):
            if label == n:
                st.markdown(message)
def plot_heatmap(labels, heatmap, rotation=90):
    """Render a similarity matrix as a seaborn heatmap inside Streamlit.

    Parameters
    ----------
    labels : sequence of str
        Tick labels used for both axes (one per row/column of ``heatmap``).
    heatmap : 2-D array-like
        Values to plot; the colour scale is fixed to the range [-1, 1].
    rotation : float, default 90
        Rotation, in degrees, applied to the x-axis tick labels.
    """
    sns.set(font_scale=1.2)
    fig, ax = plt.subplots()
    g = sns.heatmap(
        heatmap,
        xticklabels=labels,
        yticklabels=labels,
        vmin=-1,
        vmax=1,
        cmap="coolwarm",
        ax=ax,
    )
    g.set_xticklabels(labels, rotation=rotation)
    g.set_title("Textual Similarity")
    st.pyplot(fig)
    # Streamlit re-runs the script on every interaction; pyplot keeps every
    # figure alive until it is closed, so release it once it has been rendered.
    plt.close(fig)
# Main Streamlit script. Streamlit re-executes everything below, top to
# bottom, on every widget interaction.

# Streamlit input widgets
text = st.text_area(
    'Enter sentences:',
    value=(
        "The sun is hotter than the moon.\n"
        "The sun is very bright.\n"
        "I hear that the universe is very large.\n"
        "Today is Tuesday."
    ),
)
nc = st.slider('Select a number of clusters:', min_value=1, max_value=15, value=3)
model_type = st.radio("Choose model:", ('Sentence Transformer', 'Universal Sentence Encoder'), index=0)


def _load_model(kind):
    """Return the embedding model for the UI choice *kind*.

    "Sentence Transformer" loads SentenceTransformer
    'paraphrase-distilroberta-base-v1'; "Universal Sentence Encoder" loads the
    TF-Hub universal-sentence-encoder-large/5 module. Any other value raises
    ValueError.
    """
    if kind == "Sentence Transformer":
        return SentenceTransformer('paraphrase-distilroberta-base-v1')
    if kind == "Universal Sentence Encoder":
        model_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
        return hub.load(model_url)
    raise ValueError("Unknown model type: %r" % (kind,))


# Model loading is slow and the script reruns on every interaction, so cache the
# loaded model across reruns when st.cache_resource exists (newer Streamlit).
# On older Streamlit this falls back to loading on every run, as before.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)
model = _load_model(model_type)


def _ensure_nltk_punkt():
    """Download NLTK sentence-tokenizer data only when it is not installed.

    Older NLTK releases load 'punkt' in sent_tokenize while recent ones load
    'punkt_tab', so both are checked.
    """
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find("tokenizers/" + resource)
        except LookupError:
            nltk.download(resource)


_ensure_nltk_punkt()

# Run model
sentences = nltk.tokenize.sent_tokenize(text) if text else []
# Guard on the tokenized sentences, not the raw text: whitespace-only input
# yields no sentences, and an empty embedding matrix breaks the heatmap and KMeans.
if sentences:
    if model_type == "Sentence Transformer":
        embed = model.encode(sentences)
    else:
        embed = model(sentences).numpy()

    # Pairwise cosine similarity (1 - cosine distance) for every sentence pair.
    sim = np.array([[1.0 - cosine(a, b) for b in embed] for a in embed])

    st.subheader("Similarity Heatmap")
    plot_heatmap(sentences, sim)

    st.subheader("Results from K-Means Clustering")
    # KMeans cannot form more clusters than there are sentences.
    n_clusters = min(nc, len(sentences))
    if n_clusters < nc:
        st.caption(
            "Only %d sentence(s) entered; using %d cluster(s)."
            % (len(sentences), n_clusters)
        )
    cluster_examples(sentences, embed, n_clusters)