edugp committed on
Commit
76bc904
·
1 Parent(s): 8faba1d

Add flax-sentence-embeddings/all_datasets_v3_mpnet-base model

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -16,6 +16,8 @@ from sklearn.manifold import TSNE
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
 
 
19
  SEED = 0
20
 
21
 
@@ -105,7 +107,7 @@ def generate_plot(
105
  df = df.dropna(subset=[text_column, label_column])
106
  if sample:
107
  df = df.sample(min(sample, df.shape[0]), random_state=SEED)
108
- with st.spinner(text='Embedding text...'):
109
  embeddings = embed_text(df[text_column].values.tolist(), model)
110
  logger.info("Encoding labels")
111
  encoded_labels = encode_labels(df[label_column])
@@ -133,8 +135,8 @@ with col3:
133
  text_column = st.text_input("Text column name", "text")
134
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
135
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
136
- dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], 0)
137
- model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"], 0)
138
  with st.spinner(text="Loading model..."):
139
  model = load_model(model_name)
140
  dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
 
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
+ EMBEDDING_MODELS = ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2", "flax-sentence-embeddings/all_datasets_v3_mpnet-base"]
20
+ DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]
21
  SEED = 0
22
 
23
 
 
107
  df = df.dropna(subset=[text_column, label_column])
108
  if sample:
109
  df = df.sample(min(sample, df.shape[0]), random_state=SEED)
110
+ with st.spinner(text="Embedding text..."):
111
  embeddings = embed_text(df[text_column].values.tolist(), model)
112
  logger.info("Encoding labels")
113
  encoded_labels = encode_labels(df[label_column])
 
135
  text_column = st.text_input("Text column name", "text")
136
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
137
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
138
+ dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
139
+ model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
140
  with st.spinner(text="Loading model..."):
141
  model = load_model(model_name)
142
  dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings