asoria HF staff commited on
Commit
9e7becb
1 Parent(s): 18889e4

Adding plot type selection

Browse files
Files changed (1) hide show
  1. app.py +86 -12
app.py CHANGED
@@ -1,29 +1,39 @@
1
- import requests
2
  import logging
 
 
 
3
  import duckdb
4
  import numpy as np
 
 
5
  from torch import cuda
6
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
7
  from bertopic import BERTopic
8
  from bertopic.representation import KeyBERTInspired
9
-
10
  from cuml.manifold import UMAP
11
  from cuml.cluster import HDBSCAN
12
 
13
  from sklearn.feature_extraction.text import CountVectorizer
14
-
15
  from sentence_transformers import SentenceTransformer
16
 
17
  from dotenv import load_dotenv
18
- import os
19
 
 
20
  # import spaces
21
  import gradio as gr
22
 
23
 
24
  """
25
  TODOs:
26
- - Try for small dataset <1000 rows
 
 
 
 
 
 
 
 
27
  """
28
 
29
  load_dotenv()
@@ -137,7 +147,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
137
  logging.info("Global model updated")
138
 
139
 
140
- def generate_topics(dataset, config, split, column, nested_column):
141
  logging.info(
142
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
143
  )
@@ -202,12 +212,65 @@ def generate_topics(dataset, config, split, column, nested_column):
202
  reduced_embeddings_list.append(reduced_embeddings)
203
 
204
  all_docs.extend(docs)
 
205
 
206
  topics_info = base_model.get_topic_info()
207
- topic_plot = base_model.visualize_documents(
208
- all_docs,
209
- reduced_embeddings=np.vstack(reduced_embeddings_list),
210
- custom_labels=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  )
212
 
213
  rows_processed += len(docs)
@@ -228,6 +291,7 @@ def generate_topics(dataset, config, split, column, nested_column):
228
  offset += CHUNK_SIZE
229
 
230
  logging.info("Finished processing all data")
 
231
  yield (
232
  topics_info,
233
  topic_plot,
@@ -271,7 +335,12 @@ with gr.Blocks() as demo:
271
  nested_text_column_dropdown = gr.Dropdown(
272
  label="Nested text column name", visible=False
273
  )
274
-
 
 
 
 
 
275
  generate_button = gr.Button("Generate Topics", variant="primary")
276
 
277
  gr.Markdown("## Data map")
@@ -287,8 +356,13 @@ with gr.Blocks() as demo:
287
  split_dropdown,
288
  text_column_dropdown,
289
  nested_text_column_dropdown,
 
 
 
 
 
 
290
  ],
291
- outputs=[topics_df, topics_plot, full_topics_generation_label],
292
  )
293
 
294
  def _resolve_dataset_selection(
 
 
1
  import logging
2
+ import os
3
+
4
+ import datamapplot
5
  import duckdb
6
  import numpy as np
7
+ import requests
8
+
9
  from torch import cuda
10
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
11
  from bertopic import BERTopic
12
  from bertopic.representation import KeyBERTInspired
 
13
  from cuml.manifold import UMAP
14
  from cuml.cluster import HDBSCAN
15
 
16
  from sklearn.feature_extraction.text import CountVectorizer
 
17
  from sentence_transformers import SentenceTransformer
18
 
19
  from dotenv import load_dotenv
 
20
 
21
+ # These imports at the end because of torch/datamapplot issue in Zero GPU
22
  # import spaces
23
  import gradio as gr
24
 
25
 
26
  """
27
  TODOs:
28
+ - Hide params panel when generating plot
29
+
30
+ - Improve DataMapPlot plot arguments
31
+ - Add export button for final plot
32
+ - Export and serve an interactive HTML plot?
33
+ - Try with more rows
34
+
35
+ - Add TextGenerationLayer
36
+ - Make it run on Zero GPU
37
  """
38
 
39
  load_dotenv()
 
147
  logging.info("Global model updated")
148
 
149
 
150
+ def generate_topics(dataset, config, split, column, nested_column, plot_type):
151
  logging.info(
152
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
153
  )
 
212
  reduced_embeddings_list.append(reduced_embeddings)
213
 
214
  all_docs.extend(docs)
215
+ reduced_embeddings_array = np.vstack(reduced_embeddings_list)
216
 
217
  topics_info = base_model.get_topic_info()
218
+ all_topics, _ = base_model.transform(all_docs)
219
+ all_topics = np.array(all_topics)
220
+ # topic_plot, _ = datamapplot.create_plot(
221
+ # data_map_coords=reduced_embeddings_array,
222
+ # labels=all_topics.astype(str),
223
+ # use_medoids=True,
224
+ # figsize=(12, 12),
225
+ # dpi=100,
226
+ # title="PubMed - Literature review",
227
+ # sub_title="A data map of papers representing artificial intelligence and machine learning in ophthalmology",
228
+ # title_keywords={"fontsize": 36, "fontfamily": "Roboto Black"},
229
+ # sub_title_keywords={
230
+ # "fontsize": 18,
231
+ # },
232
+ # highlight_label_keywords={
233
+ # "fontsize": 12,
234
+ # "fontweight": "bold",
235
+ # "bbox": {"boxstyle": "round"},
236
+ # },
237
+ # label_font_size=8,
238
+ # label_wrap_width=16,
239
+ # label_linespacing=1.25,
240
+ # label_direction_bias=1.3,
241
+ # label_margin_factor=2.0,
242
+ # label_base_radius=15.0,
243
+ # point_size=4,
244
+ # marker_type="o",
245
+ # arrowprops={
246
+ # "arrowstyle": "wedge,tail_width=0.5",
247
+ # "connectionstyle": "arc3,rad=0.05",
248
+ # "linewidth": 0,
249
+ # "fc": "#33333377",
250
+ # },
251
+ # add_glow=True,
252
+ # glow_keywords={
253
+ # "kernel_bandwidth": 0.75, # controls how wide the glow spreads.
254
+ # "kernel": "cosine", # controls the kernel type. Default is "gaussian". See https://scikit-learn.org/stable/modules/density.html#kernel-density.
255
+ # "n_levels": 32, # controls how many "levels" there are in the contour plot.
256
+ # "max_alpha": 0.9, # controls the translucency of the glow.
257
+ # },
258
+ # darkmode=False,
259
+ # )
260
+
261
+ topic_plot = (
262
+ base_model.visualize_document_datamap(
263
+ docs=all_docs,
264
+ reduced_embeddings=reduced_embeddings_array,
265
+ title=f"<b>{dataset}</b>",
266
+ )
267
+ if plot_type == "DataMapPlot"
268
+ else base_model.visualize_documents(
269
+ docs=all_docs,
270
+ reduced_embeddings=reduced_embeddings_array,
271
+ custom_labels=True,
272
+ title=f"<b>{dataset}</b>",
273
+ )
274
  )
275
 
276
  rows_processed += len(docs)
 
291
  offset += CHUNK_SIZE
292
 
293
  logging.info("Finished processing all data")
294
+
295
  yield (
296
  topics_info,
297
  topic_plot,
 
335
  nested_text_column_dropdown = gr.Dropdown(
336
  label="Nested text column name", visible=False
337
  )
338
+ plot_type_radio = gr.Radio(
339
+ ["DataMapPlot", "Plotly"],
340
+ value="Plotly",
341
+ label="Choose the plot type",
342
+ interactive=True,
343
+ )
344
  generate_button = gr.Button("Generate Topics", variant="primary")
345
 
346
  gr.Markdown("## Data map")
 
356
  split_dropdown,
357
  text_column_dropdown,
358
  nested_text_column_dropdown,
359
+ plot_type_radio,
360
+ ],
361
+ outputs=[
362
+ topics_df,
363
+ topics_plot,
364
+ full_topics_generation_label,
365
  ],
 
366
  )
367
 
368
  def _resolve_dataset_selection(