Spaces:
Runtime error
Runtime error
Commit
·
fe78a4f
1
Parent(s):
67cc9a7
update
Browse files
app.py
CHANGED
@@ -11,15 +11,24 @@ mock_words = [
|
|
11 |
"cat", "dog", "rabbit", "hamster" # Pets
|
12 |
]
|
13 |
|
14 |
-
# Define available models
|
15 |
models = {
|
16 |
'DistilBERT': 'distilbert-base-uncased',
|
17 |
'BERT': 'bert-base-uncased',
|
18 |
'RoBERTa': 'roberta-base'
|
19 |
}
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def embed_words(words, model_name):
|
22 |
-
embedder =
|
23 |
embeddings = embedder(words)
|
24 |
return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
|
25 |
|
@@ -51,10 +60,9 @@ def main():
|
|
51 |
|
52 |
# Dropdown menu for selecting the embedding model
|
53 |
model_name = st.selectbox("Select Embedding Model", list(models.keys()))
|
54 |
-
selected_model = models[model_name]
|
55 |
|
56 |
if st.button("Generate Clusters"):
|
57 |
-
clusters = cluster_words(mock_words,
|
58 |
display_clusters(clusters)
|
59 |
|
60 |
if __name__ == "__main__":
|
|
|
11 |
"cat", "dog", "rabbit", "hamster" # Pets
|
12 |
]
|
13 |
|
14 |
+
# Define available models and load them
|
15 |
models = {
|
16 |
'DistilBERT': 'distilbert-base-uncased',
|
17 |
'BERT': 'bert-base-uncased',
|
18 |
'RoBERTa': 'roberta-base'
|
19 |
}
|
20 |
|
21 |
+
@st.cache(allow_output_mutation=True)
|
22 |
+
def load_models():
|
23 |
+
pipelines = {}
|
24 |
+
for name, model_name in models.items():
|
25 |
+
pipelines[name] = pipeline('feature-extraction', model=model_name)
|
26 |
+
return pipelines
|
27 |
+
|
28 |
+
pipelines = load_models()
|
29 |
+
|
30 |
def embed_words(words, model_name):
|
31 |
+
embedder = pipelines[model_name]
|
32 |
embeddings = embedder(words)
|
33 |
return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
|
34 |
|
|
|
60 |
|
61 |
# Dropdown menu for selecting the embedding model
|
62 |
model_name = st.selectbox("Select Embedding Model", list(models.keys()))
|
|
|
63 |
|
64 |
if st.button("Generate Clusters"):
|
65 |
+
clusters = cluster_words(mock_words, model_name)
|
66 |
display_clusters(clusters)
|
67 |
|
68 |
if __name__ == "__main__":
|