asoria HF staff commited on
Commit
119b257
1 Parent(s): dfa9cba

Export PNG from plot

Browse files
Files changed (2) hide show
  1. app.py +62 -49
  2. requirements.txt +1 -0
app.py CHANGED
@@ -12,7 +12,7 @@ from bertopic import BERTopic
12
  from bertopic.representation import KeyBERTInspired
13
  from cuml.manifold import UMAP
14
  from cuml.cluster import HDBSCAN
15
-
16
  from sklearn.feature_extraction.text import CountVectorizer
17
  from sentence_transformers import SentenceTransformer
18
 
@@ -25,12 +25,10 @@ import gradio as gr
25
 
26
  """
27
  TODOs:
28
- - Improve DataMapPlot plot arguments
29
- - Add export button for final plot
30
- - Export and serve an interactive HTML plot?
31
  - Try with more rows
32
-
33
  - Add TextGenerationLayer
 
 
34
  - Make it run on Zero GPU
35
  """
36
 
@@ -38,6 +36,12 @@ load_dotenv()
38
  HF_TOKEN = os.getenv("HF_TOKEN")
39
  assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
40
 
 
 
 
 
 
 
41
  logging.basicConfig(
42
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
43
  )
@@ -145,6 +149,27 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
145
  logging.info("Global model updated")
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
149
  logging.info(
150
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
@@ -159,7 +184,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
159
 
160
  reduce_umap_model = UMAP(
161
  n_neighbors=n_neighbors,
162
- n_components=2, # For visualization, keeping it at 2 (2D)
163
  min_dist=0.0,
164
  metric="cosine",
165
  random_state=42,
@@ -183,6 +208,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
183
  gr.DataFrame(value=[], interactive=False, visible=True),
184
  gr.Plot(value=None, visible=True),
185
  gr.Label({message: rows_processed / limit}, visible=True),
 
186
  )
187
  while offset < limit:
188
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
@@ -216,59 +242,32 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
216
  topics_info = base_model.get_topic_info()
217
  all_topics, _ = base_model.transform(all_docs)
218
  all_topics = np.array(all_topics)
219
- # topic_plot, _ = datamapplot.create_plot(
220
- # data_map_coords=reduced_embeddings_array,
221
- # labels=all_topics.astype(str),
222
- # use_medoids=True,
223
- # figsize=(12, 12),
224
- # dpi=100,
225
- # title="PubMed - Literature review",
226
- # sub_title="A data map of papers representing artificial intelligence and machine learning in ophthalmology",
227
- # title_keywords={"fontsize": 36, "fontfamily": "Roboto Black"},
228
- # sub_title_keywords={
229
- # "fontsize": 18,
230
- # },
231
- # highlight_label_keywords={
232
- # "fontsize": 12,
233
- # "fontweight": "bold",
234
- # "bbox": {"boxstyle": "round"},
235
- # },
236
- # label_font_size=8,
237
- # label_wrap_width=16,
238
- # label_linespacing=1.25,
239
- # label_direction_bias=1.3,
240
- # label_margin_factor=2.0,
241
- # label_base_radius=15.0,
242
- # point_size=4,
243
- # marker_type="o",
244
- # arrowprops={
245
- # "arrowstyle": "wedge,tail_width=0.5",
246
- # "connectionstyle": "arc3,rad=0.05",
247
- # "linewidth": 0,
248
- # "fc": "#33333377",
249
- # },
250
- # add_glow=True,
251
- # glow_keywords={
252
- # "kernel_bandwidth": 0.75, # controls how wide the glow spreads.
253
- # "kernel": "cosine", # controls the kernel type. Default is "gaussian". See https://scikit-learn.org/stable/modules/density.html#kernel-density.
254
- # "n_levels": 32, # controls how many "levels" there are in the contour plot.
255
- # "max_alpha": 0.9, # controls the translucency of the glow.
256
- # },
257
- # darkmode=False,
258
- # )
259
 
260
  topic_plot = (
261
  base_model.visualize_document_datamap(
262
  docs=all_docs,
263
  reduced_embeddings=reduced_embeddings_array,
264
- title=f"<b>{dataset}</b>",
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  )
266
  if plot_type == "DataMapPlot"
267
  else base_model.visualize_documents(
268
  docs=all_docs,
269
  reduced_embeddings=reduced_embeddings_array,
270
  custom_labels=True,
271
- title=f"<b>{dataset}</b>",
272
  )
273
  )
274
 
@@ -286,12 +285,23 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
286
  topics_info,
287
  topic_plot,
288
  gr.Label({message: progress}, visible=True),
 
289
  )
290
 
291
  offset += CHUNK_SIZE
292
 
293
  logging.info("Finished processing all data")
294
 
 
 
 
 
 
 
 
 
 
 
295
  yield (
296
  gr.Accordion(open=False),
297
  topics_info,
@@ -299,6 +309,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
299
  gr.Label(
300
  {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
301
  ),
 
302
  )
303
  cuda.empty_cache()
304
 
@@ -339,7 +350,7 @@ with gr.Blocks() as demo:
339
  )
340
  plot_type_radio = gr.Radio(
341
  ["DataMapPlot", "Plotly"],
342
- value="Plotly",
343
  label="Choose the plot type",
344
  interactive=True,
345
  )
@@ -347,6 +358,7 @@ with gr.Blocks() as demo:
347
 
348
  gr.Markdown("## Data map")
349
  full_topics_generation_label = gr.Label(visible=False, show_label=False)
 
350
  topics_plot = gr.Plot()
351
  with gr.Accordion("Topics Info", open=False):
352
  topics_df = gr.DataFrame(interactive=False, visible=True)
@@ -365,6 +377,7 @@ with gr.Blocks() as demo:
365
  topics_df,
366
  topics_plot,
367
  full_topics_generation_label,
 
368
  ],
369
  )
370
 
 
12
  from bertopic.representation import KeyBERTInspired
13
  from cuml.manifold import UMAP
14
  from cuml.cluster import HDBSCAN
15
+ from huggingface_hub import HfApi
16
  from sklearn.feature_extraction.text import CountVectorizer
17
  from sentence_transformers import SentenceTransformer
18
 
 
25
 
26
  """
27
  TODOs:
 
 
 
28
  - Try with more rows
 
29
  - Add TextGenerationLayer
30
+ - Try with more rows
31
+ - Export and serve an interactive HTML plot?
32
  - Make it run on Zero GPU
33
  """
34
 
 
36
  HF_TOKEN = os.getenv("HF_TOKEN")
37
  assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
38
 
39
+
40
+ EXPORTS_REPOSITORY = os.getenv("EXPORTS_REPOSITORY")
41
+ assert (
42
+ EXPORTS_REPOSITORY is not None
43
+ ), "You need to set EXPORTS_REPOSITORY in your environment variables"
44
+
45
  logging.basicConfig(
46
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
47
  )
 
149
  logging.info("Global model updated")
150
 
151
 
152
+ def _push_to_hub(
153
+ dataset_id,
154
+ file_path,
155
+ ):
156
+ logging.info(f"Pushing file to hub: {dataset_id} on file {file_path}")
157
+
158
+ file_name = file_path.split("/")[-1]
159
+ api = HfApi(token=HF_TOKEN)
160
+ try:
161
+ logging.info(f"About to push {file_path} - {dataset_id}")
162
+ api.upload_file(
163
+ path_or_fileobj=file_path,
164
+ path_in_repo=file_name,
165
+ repo_id=EXPORTS_REPOSITORY,
166
+ repo_type="dataset",
167
+ )
168
+ except Exception as e:
169
+ logging.info("Failed to push file", e)
170
+ raise
171
+
172
+
173
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
174
  logging.info(
175
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
 
184
 
185
  reduce_umap_model = UMAP(
186
  n_neighbors=n_neighbors,
187
+ n_components=2, # For visualization, keeping it for 2D
188
  min_dist=0.0,
189
  metric="cosine",
190
  random_state=42,
 
208
  gr.DataFrame(value=[], interactive=False, visible=True),
209
  gr.Plot(value=None, visible=True),
210
  gr.Label({message: rows_processed / limit}, visible=True),
211
+ "",
212
  )
213
  while offset < limit:
214
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
 
242
  topics_info = base_model.get_topic_info()
243
  all_topics, _ = base_model.transform(all_docs)
244
  all_topics = np.array(all_topics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  topic_plot = (
247
  base_model.visualize_document_datamap(
248
  docs=all_docs,
249
  reduced_embeddings=reduced_embeddings_array,
250
+ title=dataset,
251
+ width=800,
252
+ height=700,
253
+ # arrowprops={
254
+ # "arrowstyle": "wedge,tail_width=0.5",
255
+ # "connectionstyle": "arc3,rad=0.05",
256
+ # "linewidth": 0,
257
+ # "fc": "#33333377",
258
+ # },
259
+ label_wrap_width=12,
260
+ label_over_points=True,
261
+ dynamic_label_size=True,
262
+ max_font_size=36,
263
+ min_font_size=4,
264
  )
265
  if plot_type == "DataMapPlot"
266
  else base_model.visualize_documents(
267
  docs=all_docs,
268
  reduced_embeddings=reduced_embeddings_array,
269
  custom_labels=True,
270
+ title=dataset,
271
  )
272
  )
273
 
 
285
  topics_info,
286
  topic_plot,
287
  gr.Label({message: progress}, visible=True),
288
+ "",
289
  )
290
 
291
  offset += CHUNK_SIZE
292
 
293
  logging.info("Finished processing all data")
294
 
295
+ plot_png = f"{dataset.replace('/', '-')}-{plot_type.lower()}.png"
296
+ if plot_type == "DataMapPlot":
297
+ topic_plot.savefig(plot_png, format="png", dpi=300)
298
+ else:
299
+ topic_plot.write_image(plot_png)
300
+
301
+ _push_to_hub(dataset, plot_png)
302
+ plot_png_link = (
303
+ f"https://huggingface.co/datasets/{EXPORTS_REPOSITORY}/blob/main/{plot_png}"
304
+ )
305
  yield (
306
  gr.Accordion(open=False),
307
  topics_info,
 
309
  gr.Label(
310
  {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
311
  ),
312
+ f"[![Download as PNG](https://img.shields.io/badge/Download_as-PNG-red)]({plot_png_link})",
313
  )
314
  cuda.empty_cache()
315
 
 
350
  )
351
  plot_type_radio = gr.Radio(
352
  ["DataMapPlot", "Plotly"],
353
+ value="DataMapPlot",
354
  label="Choose the plot type",
355
  interactive=True,
356
  )
 
358
 
359
  gr.Markdown("## Data map")
360
  full_topics_generation_label = gr.Label(visible=False, show_label=False)
361
+ open_png_label = gr.Markdown()
362
  topics_plot = gr.Plot()
363
  with gr.Accordion("Topics Info", open=False):
364
  topics_df = gr.DataFrame(interactive=False, visible=True)
 
377
  topics_df,
378
  topics_plot,
379
  full_topics_generation_label,
380
+ open_png_label,
381
  ],
382
  )
383
 
requirements.txt CHANGED
@@ -12,3 +12,4 @@ pandas
12
  torch
13
  numpy
14
  python-dotenv
 
 
12
  torch
13
  numpy
14
  python-dotenv
15
+ kaleido