"""Streamlit demo: sentence similarity heatmap and K-Means clustering.

The user enters free text, which is split into sentences with NLTK. Each
sentence is embedded with either a SentenceTransformer model or the Universal
Sentence Encoder, and the app then shows a pairwise cosine-similarity heatmap
followed by the sentences grouped by K-Means cluster.
"""

import streamlit as st
import nltk
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import tensorflow as tf
import tensorflow_hub as hub


def cluster_examples(messages, embed, nc=3):
    """Cluster the embeddings with K-Means and print each cluster's messages.

    Args:
        messages: Sequence of strings; ``messages[i]`` is the text whose
            embedding is row ``i`` of ``embed``.
        embed: Array-like of embeddings passed to ``KMeans.fit_predict``.
        nc: Requested number of clusters. It is capped at the number of
            embeddings, because K-Means cannot form more clusters than it
            has samples.
    """
    n_samples = len(embed)
    if n_samples == 0:
        # Nothing to cluster; KMeans would raise on an empty input.
        return

    n_clusters = min(nc, n_samples)
    km = KMeans(
        n_clusters=n_clusters,
        init='random',
        n_init=10,
        max_iter=300,
        tol=1e-04,
        random_state=0,
    )
    assignments = km.fit_predict(embed)

    for n in range(n_clusters):
        st.markdown(f"CLUSTER : {n}")
        for message, label in zip(messages, assignments):
            if label == n:
                st.markdown(message)


def plot_heatmap(labels, heatmap, rotation=90):
    """Render a square similarity matrix as a heatmap in the Streamlit page.

    Args:
        labels: Tick labels used for both the x and y axes.
        heatmap: 2-D matrix of similarity values; colors are scaled over
            the fixed range [-1, 1].
        rotation: Rotation, in degrees, applied to the x tick labels.
    """
    sns.set(font_scale=1.2)
    fig, ax = plt.subplots()
    g = sns.heatmap(
        heatmap,
        xticklabels=labels,
        yticklabels=labels,
        vmin=-1,
        vmax=1,
        cmap="coolwarm",
        ax=ax,  # draw on the axes created above, not the implicit current axes
    )
    g.set_xticklabels(labels, rotation=rotation)
    g.set_title("Textual Similarity")
    st.pyplot(fig)
    # The script is re-executed on every widget interaction; close the figure
    # so pyplot does not keep one more open figure per rerun.
    plt.close(fig)


def _similarity_matrix(embed):
    """Return the matrix of pairwise cosine similarities between rows of embed."""
    size = len(embed)
    sim = np.zeros([size, size])
    for i, em in enumerate(embed):
        for j, ea in enumerate(embed):
            # scipy's `cosine` is a distance; similarity is its complement.
            sim[i][j] = 1.0 - cosine(em, ea)
    return sim


# Streamlit text boxes
text = st.text_area(
    'Enter sentences:',
    value="The sun is hotter than the moon.\nThe sun is very bright.\nI hear that the universe is very large.\nToday is Tuesday.",
)

nc = st.slider('Select a number of clusters:', min_value=1, max_value=15, value=3)

model_type = st.radio(
    "Choose model:",
    ('Sentence Transformer', 'Universal Sentence Encoder'),
    index=0,
)

# Model setup
# NOTE(review): the model is loaded on every script run; consider Streamlit's
# resource caching if reload time becomes a problem.
if model_type == "Sentence Transformer":
    model = SentenceTransformer('paraphrase-distilroberta-base-v1')
elif model_type == "Universal Sentence Encoder":
    model_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
    model = hub.load(model_url)

nltk.download('punkt')

# Run model
if text:
    sentences = nltk.tokenize.sent_tokenize(text)

    # Skip rendering when the tokenizer yields no sentences, so the model,
    # heatmap and K-Means are never called on empty input.
    if sentences:
        if model_type == "Sentence Transformer":
            embed = model.encode(sentences)
        elif model_type == "Universal Sentence Encoder":
            embed = model(sentences).numpy()

        sim = _similarity_matrix(embed)

        st.subheader("Similarity Heatmap")
        plot_heatmap(sentences, sim)

        st.subheader("Results from K-Means Clustering")
        cluster_examples(sentences, embed, nc)