jharrison27 committed on
Commit
fe78a4f
·
1 Parent(s): 67cc9a7
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -11,15 +11,24 @@ mock_words = [
11
  "cat", "dog", "rabbit", "hamster" # Pets
12
  ]
13
 
14
- # Define available models
15
  models = {
16
  'DistilBERT': 'distilbert-base-uncased',
17
  'BERT': 'bert-base-uncased',
18
  'RoBERTa': 'roberta-base'
19
  }
20
 
 
 
 
 
 
 
 
 
 
21
  def embed_words(words, model_name):
22
- embedder = pipeline('feature-extraction', model=model_name)
23
  embeddings = embedder(words)
24
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
25
 
@@ -51,10 +60,9 @@ def main():
51
 
52
  # Dropdown menu for selecting the embedding model
53
  model_name = st.selectbox("Select Embedding Model", list(models.keys()))
54
- selected_model = models[model_name]
55
 
56
  if st.button("Generate Clusters"):
57
- clusters = cluster_words(mock_words, selected_model)
58
  display_clusters(clusters)
59
 
60
  if __name__ == "__main__":
 
11
  "cat", "dog", "rabbit", "hamster" # Pets
12
  ]
13
 
14
+ # Define available models and load them
15
  models = {
16
  'DistilBERT': 'distilbert-base-uncased',
17
  'BERT': 'bert-base-uncased',
18
  'RoBERTa': 'roberta-base'
19
  }
20
 
21
+ @st.cache(allow_output_mutation=True)
22
+ def load_models():
23
+ pipelines = {}
24
+ for name, model_name in models.items():
25
+ pipelines[name] = pipeline('feature-extraction', model=model_name)
26
+ return pipelines
27
+
28
+ pipelines = load_models()
29
+
30
  def embed_words(words, model_name):
31
+ embedder = pipelines[model_name]
32
  embeddings = embedder(words)
33
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
34
 
 
60
 
61
  # Dropdown menu for selecting the embedding model
62
  model_name = st.selectbox("Select Embedding Model", list(models.keys()))
 
63
 
64
  if st.button("Generate Clusters"):
65
+ clusters = cluster_words(mock_words, model_name)
66
  display_clusters(clusters)
67
 
68
  if __name__ == "__main__":