edugp committed on
Commit
8faba1d
·
1 Parent(s): 75fc948

Add spinners

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -98,7 +98,6 @@ def generate_plot(
98
  dimensionality_reduction_function: Callable,
99
  model: SentenceTransformer,
100
  ) -> Figure:
101
- logger.info("Loading dataset in memory")
102
  if text_column not in df.columns:
103
  raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
104
  if label_column not in df.columns:
@@ -106,12 +105,12 @@ def generate_plot(
106
  df = df.dropna(subset=[text_column, label_column])
107
  if sample:
108
  df = df.sample(min(sample, df.shape[0]), random_state=SEED)
109
- logger.info("Embedding sentences")
110
- embeddings = embed_text(df[text_column].values.tolist(), model)
111
  logger.info("Encoding labels")
112
  encoded_labels = encode_labels(df[label_column])
113
- logger.info("Running dimensionality reduction")
114
- embeddings_2d = dimensionality_reduction_function(embeddings)
115
  logger.info("Generating figure")
116
  plot = draw_interactive_scatter_plot(
117
  df[text_column].values, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
@@ -136,14 +135,16 @@ label_column = st.text_input("Numerical/categorical column name (ignore if not a
136
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
137
  dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], 0)
138
  model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"], 0)
139
- model = load_model(model_name)
 
140
  dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
141
 
142
  if uploaded_file or hub_dataset:
143
- if uploaded_file:
144
- df = uploaded_file_to_dataframe(uploaded_file)
145
- else:
146
- df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
 
147
  plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
148
  logger.info("Displaying plot")
149
  st.bokeh_chart(plot)
 
98
  dimensionality_reduction_function: Callable,
99
  model: SentenceTransformer,
100
  ) -> Figure:
 
101
  if text_column not in df.columns:
102
  raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
103
  if label_column not in df.columns:
 
105
  df = df.dropna(subset=[text_column, label_column])
106
  if sample:
107
  df = df.sample(min(sample, df.shape[0]), random_state=SEED)
108
+ with st.spinner(text='Embedding text...'):
109
+ embeddings = embed_text(df[text_column].values.tolist(), model)
110
  logger.info("Encoding labels")
111
  encoded_labels = encode_labels(df[label_column])
112
+ with st.spinner("Reducing dimensionality..."):
113
+ embeddings_2d = dimensionality_reduction_function(embeddings)
114
  logger.info("Generating figure")
115
  plot = draw_interactive_scatter_plot(
116
  df[text_column].values, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
 
135
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
136
  dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], 0)
137
  model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"], 0)
138
+ with st.spinner(text="Loading model..."):
139
+ model = load_model(model_name)
140
  dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
141
 
142
  if uploaded_file or hub_dataset:
143
+ with st.spinner("Loading dataset..."):
144
+ if uploaded_file:
145
+ df = uploaded_file_to_dataframe(uploaded_file)
146
+ else:
147
+ df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
148
  plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
149
  logger.info("Displaying plot")
150
  st.bokeh_chart(plot)