asoria HF staff commited on
Commit
2b40426
β€’
1 Parent(s): 24bed82

Try to fix plot

Browse files
Files changed (1) hide show
  1. app.py +208 -216
app.py CHANGED
@@ -37,6 +37,7 @@ DATASETS_TOPICS_ORGANIZATION = os.getenv(
37
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
38
  )
39
  USE_CUML = int(os.getenv("USE_CUML", "1"))
 
40
 
41
  # Use cuml lib only if configured
42
  if USE_CUML:
@@ -52,17 +53,19 @@ logging.basicConfig(
52
  )
53
 
54
  api = HfApi(token=HF_TOKEN)
 
55
 
 
56
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
57
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
58
- vectorizer_model = CountVectorizer(stop_words="english")
59
  representation_model = KeyBERTInspired()
 
60
 
61
  inference_client = InferenceClient(model_id)
62
 
63
 
64
  def calculate_embeddings(docs):
65
- return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
66
 
67
 
68
  def calculate_n_neighbors_and_components(n_rows):
@@ -92,7 +95,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
92
  new_model = BERTopic(
93
  language="english",
94
  # Sub-models
95
- embedding_model=embedding_model, # Step 1 - Extract embeddings
96
  umap_model=umap_model, # Step 2 - UMAP model
97
  hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
98
  vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
@@ -166,146 +169,44 @@ def generate_topics(dataset, config, split, column, plot_type):
166
  "",
167
  )
168
 
169
- try:
170
- while offset < limit:
171
- logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
172
- docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
173
- if not docs:
174
- break
175
- logging.info(f"Got {len(docs)} docs βœ“")
176
- embeddings = calculate_embeddings(docs)
177
- new_model = fit_model(docs, embeddings, n_neighbors, n_components)
178
-
179
- if base_model is None:
180
- base_model = new_model
181
- logging.info(
182
- f"The following topics are newly found: {base_model.topic_labels_}"
183
- )
184
- else:
185
- updated_model = BERTopic.merge_models([base_model, new_model])
186
- nr_new_topics = len(set(updated_model.topics_)) - len(
187
- set(base_model.topics_)
188
- )
189
- new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
190
- logging.info(f"The following topics are newly found: {new_topics}")
191
- base_model = updated_model
192
-
193
- logging.info("Reducing embeddings to 2D")
194
- reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
195
- reduced_embeddings_list.append(reduced_embeddings)
196
-
197
- all_docs.extend(docs)
198
- reduced_embeddings_array = np.vstack(reduced_embeddings_list)
199
- logging.info("Reducing embeddings to 2D βœ“")
200
-
201
- topics_info = base_model.get_topic_info()
202
- all_topics = base_model.topics_
203
- logging.info(f"Preparing topics {plot_type} plot")
204
-
205
- topic_plot = (
206
- base_model.visualize_document_datamap(
207
- docs=all_docs,
208
- topics=all_topics,
209
- reduced_embeddings=reduced_embeddings_array,
210
- title="",
211
- sub_title=sub_title,
212
- width=800,
213
- height=700,
214
- arrowprops={
215
- "arrowstyle": "wedge,tail_width=0.5",
216
- "connectionstyle": "arc3,rad=0.05",
217
- "linewidth": 0,
218
- "fc": "#33333377",
219
- },
220
- dynamic_label_size=True,
221
- # label_wrap_width=12,
222
- label_over_points=True,
223
- max_font_size=36,
224
- min_font_size=4,
225
- )
226
- if plot_type == "DataMapPlot"
227
- else base_model.visualize_documents(
228
- docs=all_docs,
229
- topics=all_topics,
230
- reduced_embeddings=reduced_embeddings_array,
231
- custom_labels=True,
232
- title="",
233
- )
234
- )
235
- logging.info("Plot done βœ“")
236
- rows_processed += len(docs)
237
- progress = min(rows_processed / limit, 1.0)
238
- logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
239
- message = (
240
- f"Processing topics for full dataset: {rows_processed} of {limit}"
241
- if full_processing
242
- else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
243
  )
244
-
245
- yield (
246
- gr.Accordion(open=False),
247
- topics_info,
248
- topic_plot,
249
- gr.Label({"⏳ " + message: progress}, visible=True),
250
- "",
251
  )
 
 
 
252
 
253
- offset += CHUNK_SIZE
254
- del docs, embeddings, new_model, reduced_embeddings
255
- logging.info("Finished processing topic modeling data")
256
-
257
- yield (
258
- gr.Accordion(open=False),
259
- topics_info,
260
- topic_plot,
261
- gr.Label(
262
- {
263
- "βœ… " + message: 1.0,
264
- f"⏳ Generating topic names with {model_id}": 0.0,
265
- },
266
- visible=True,
267
- ),
268
- "",
269
- )
270
-
271
- all_topics = base_model.topics_
272
- topics_info = base_model.get_topic_info()
273
 
274
- new_topics_by_text_generation = {}
275
- for _, row in topics_info.iterrows():
276
- logging.info(
277
- f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
278
- )
279
- prompt = f"{LLAMA_3_8B_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
280
- prompt_messages = [
281
- {
282
- "role": "system",
283
- "content": "You are a helpful, respectful and honest assistant for labeling topics.",
284
- },
285
- {"role": "user", "content": prompt},
286
- ]
287
- output = inference_client.chat_completion(
288
- messages=prompt_messages,
289
- stream=False,
290
- max_tokens=500,
291
- top_p=0.8,
292
- seed=42,
293
- )
294
- inference_response = output.choices[0].message.content
295
- logging.info("Inference response:")
296
- logging.info(inference_response)
297
- new_topics_by_text_generation[row["Topic"]] = inference_response.replace(
298
- "Topic=", ""
299
- ).strip()
300
- base_model.set_topic_labels(new_topics_by_text_generation)
301
 
302
  topics_info = base_model.get_topic_info()
303
-
 
304
  topic_plot = (
305
  base_model.visualize_document_datamap(
306
  docs=all_docs,
307
  topics=all_topics,
308
- custom_labels=True,
309
  reduced_embeddings=reduced_embeddings_array,
310
  title="",
311
  sub_title=sub_title,
@@ -326,100 +227,191 @@ def generate_topics(dataset, config, split, column, plot_type):
326
  if plot_type == "DataMapPlot"
327
  else base_model.visualize_documents(
328
  docs=all_docs,
 
329
  reduced_embeddings=reduced_embeddings_array,
330
- custom_labels=True,
331
  title="",
332
  )
333
  )
 
 
 
 
 
 
 
 
 
334
 
335
- dataset_clear_name = dataset.replace("/", "-")
336
- plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
337
- if plot_type == "DataMapPlot":
338
- topic_plot.savefig(plot_png, format="png", dpi=300)
339
- else:
340
- topic_plot.write_image(plot_png)
341
-
342
- custom_labels = base_model.custom_labels_
343
- topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
344
  yield (
345
  gr.Accordion(open=False),
346
  topics_info,
347
  topic_plot,
348
- gr.Label(
349
- {
350
- "βœ… " + message: 1.0,
351
- f"βœ… Generating topic names with {model_id}": 1.0,
352
- "⏳ Creating Interactive Space": 0.0,
353
- },
354
- visible=True,
355
- ),
356
  "",
357
  )
358
- interactive_plot = datamapplot.create_interactive_plot(
359
- reduced_embeddings_array,
360
- topic_names_array,
361
- hover_text=all_docs,
362
- title=dataset,
363
- sub_title=sub_title.replace(
364
- "dataset",
365
- f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
366
- ),
367
- enable_search=True,
368
- # TODO: Export data to .arrow and also serve it
369
- inline_data=True,
370
- # offline_data_prefix=dataset_clear_name,
371
- initial_zoom_fraction=0.8,
372
- )
373
- html_content = str(interactive_plot)
374
- html_file_path = f"{dataset_clear_name}.html"
375
- with open(html_file_path, "w", encoding="utf-8") as html_file:
376
- html_file.write(html_content)
377
-
378
- repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
379
-
380
- space_id = create_space_with_content(
381
- api=api,
382
- repo_id=repo_id,
383
- dataset_id=dataset,
384
- html_file_path=html_file_path,
385
- plot_file_path=plot_png,
386
- space_card=SPACE_REPO_CARD_CONTENT,
387
- token=HF_TOKEN,
388
- )
389
 
390
- space_link = f"https://huggingface.co/spaces/{space_id}"
 
 
391
 
392
- yield (
393
- gr.Accordion(open=False),
394
- topics_info,
395
- topic_plot,
396
- gr.Label(
397
- {
398
- "βœ… " + message: 1.0,
399
- f"βœ… Generating topic names with {model_id}": 1.0,
400
- "βœ… Creating Interactive Space": 1.0,
401
- },
402
- visible=True,
403
- ),
404
- f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
 
 
 
 
 
 
 
 
405
  )
406
- del reduce_umap_model, all_docs, reduced_embeddings_list
407
- del (
408
- base_model,
409
- all_topics,
410
- topics_info,
411
- topic_names_array,
412
- interactive_plot,
 
 
 
 
 
 
 
413
  )
414
- cuda.empty_cache()
415
- except Exception as error:
416
- return (
417
- gr.Accordion(open=True),
418
- gr.DataFrame(value=[], interactive=False, visible=True),
419
- gr.Plot(value=None, visible=True),
420
- gr.Label({f"❌ Error: {error}": 0.0}, visible=True),
421
- "",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
 
425
  with gr.Blocks() as demo:
@@ -468,11 +460,11 @@ with gr.Blocks() as demo:
468
  generate_button = gr.Button("Generate Topics", variant="primary")
469
 
470
  gr.Markdown("## Data map")
471
- progress_label = gr.Label(visible=False, show_label=False)
472
  open_space_label = gr.Markdown()
473
  topics_plot = gr.Plot()
474
- # with gr.Accordion("Topics Info", open=False):
475
- topics_df = gr.DataFrame(interactive=False, visible=True)
476
  gr.HTML(
477
  f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
478
  )
@@ -494,7 +486,7 @@ with gr.Blocks() as demo:
494
  data_details_accordion,
495
  topics_df,
496
  topics_plot,
497
- progress_label,
498
  open_space_label,
499
  ],
500
  )
 
37
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
38
  )
39
  USE_CUML = int(os.getenv("USE_CUML", "1"))
40
+ USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
41
 
42
  # Use cuml lib only if configured
43
  if USE_CUML:
 
53
  )
54
 
55
  api = HfApi(token=HF_TOKEN)
56
+ sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
57
 
58
+ # Representation model
59
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
60
+
 
61
  representation_model = KeyBERTInspired()
62
+ vectorizer_model = CountVectorizer(stop_words="english")
63
 
64
  inference_client = InferenceClient(model_id)
65
 
66
 
67
  def calculate_embeddings(docs):
68
+ return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
69
 
70
 
71
  def calculate_n_neighbors_and_components(n_rows):
 
95
  new_model = BERTopic(
96
  language="english",
97
  # Sub-models
98
+ embedding_model=sentence_model, # Step 1 - Extract embeddings
99
  umap_model=umap_model, # Step 2 - UMAP model
100
  hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
101
  vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
 
169
  "",
170
  )
171
 
172
+ while offset < limit:
173
+ logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
174
+ docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
175
+ if not docs:
176
+ break
177
+ logging.info(f"Got {len(docs)} docs βœ“")
178
+ embeddings = calculate_embeddings(docs)
179
+ new_model = fit_model(docs, embeddings, n_neighbors, n_components)
180
+
181
+ if base_model is None:
182
+ base_model = new_model
183
+ logging.info(
184
+ f"The following topics are newly found: {base_model.topic_labels_}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  )
186
+ else:
187
+ updated_model = BERTopic.merge_models([base_model, new_model])
188
+ nr_new_topics = len(set(updated_model.topics_)) - len(
189
+ set(base_model.topics_)
 
 
 
190
  )
191
+ new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
192
+ logging.info(f"The following topics are newly found: {new_topics}")
193
+ base_model = updated_model
194
 
195
+ logging.info("Reducing embeddings to 2D")
196
+ reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
197
+ reduced_embeddings_list.append(reduced_embeddings)
198
+ logging.info("Reducing embeddings to 2D βœ“")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ all_docs.extend(docs)
201
+ reduced_embeddings_array = np.vstack(reduced_embeddings_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  topics_info = base_model.get_topic_info()
204
+ all_topics = base_model.topics_
205
+ logging.info(f"Preparing topics {plot_type} plot")
206
  topic_plot = (
207
  base_model.visualize_document_datamap(
208
  docs=all_docs,
209
  topics=all_topics,
 
210
  reduced_embeddings=reduced_embeddings_array,
211
  title="",
212
  sub_title=sub_title,
 
227
  if plot_type == "DataMapPlot"
228
  else base_model.visualize_documents(
229
  docs=all_docs,
230
+ topics=all_topics,
231
  reduced_embeddings=reduced_embeddings_array,
 
232
  title="",
233
  )
234
  )
235
+ logging.info("Plot done βœ“")
236
+ rows_processed += len(docs)
237
+ progress = min(rows_processed / limit, 1.0)
238
+ logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
239
+ message = (
240
+ f"Processing topics for full dataset: {rows_processed} of {limit}"
241
+ if full_processing
242
+ else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
243
+ )
244
 
 
 
 
 
 
 
 
 
 
245
  yield (
246
  gr.Accordion(open=False),
247
  topics_info,
248
  topic_plot,
249
+ gr.Label({"⏳ " + message: progress}, visible=True),
 
 
 
 
 
 
 
250
  "",
251
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
+ offset += CHUNK_SIZE
254
+ del docs, embeddings, new_model, reduced_embeddings
255
+ logging.info("Finished processing all data")
256
 
257
+ yield (
258
+ gr.Accordion(open=False),
259
+ topics_info,
260
+ topic_plot,
261
+ gr.Label(
262
+ {
263
+ "βœ… " + message: 1.0,
264
+ f"⏳ Generating topic names with {model_id}": 0.0,
265
+ },
266
+ visible=True,
267
+ ),
268
+ "",
269
+ )
270
+
271
+ all_topics = base_model.topics_
272
+ topics_info = base_model.get_topic_info()
273
+
274
+ new_topics_by_text_generation = {}
275
+ for _, row in topics_info.iterrows():
276
+ logging.info(
277
+ f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
278
  )
279
+ prompt = f"{LLAMA_3_8B_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
280
+ prompt_messages = [
281
+ {
282
+ "role": "system",
283
+ "content": "You are a helpful, respectful and honest assistant for labeling topics.",
284
+ },
285
+ {"role": "user", "content": prompt},
286
+ ]
287
+ output = inference_client.chat_completion(
288
+ messages=prompt_messages,
289
+ stream=False,
290
+ max_tokens=500,
291
+ top_p=0.8,
292
+ seed=42,
293
  )
294
+ inference_response = output.choices[0].message.content
295
+ logging.info("Inference response:")
296
+ logging.info(inference_response)
297
+ new_topics_by_text_generation[row["Topic"]] = inference_response.replace(
298
+ "Topic=", ""
299
+ ).strip()
300
+ base_model.set_topic_labels(new_topics_by_text_generation)
301
+
302
+ topics_info = base_model.get_topic_info()
303
+
304
+ topic_plot = (
305
+ base_model.visualize_document_datamap(
306
+ docs=all_docs,
307
+ topics=all_topics,
308
+ custom_labels=True,
309
+ reduced_embeddings=reduced_embeddings_array,
310
+ title="",
311
+ sub_title=sub_title,
312
+ width=800,
313
+ height=700,
314
+ arrowprops={
315
+ "arrowstyle": "wedge,tail_width=0.5",
316
+ "connectionstyle": "arc3,rad=0.05",
317
+ "linewidth": 0,
318
+ "fc": "#33333377",
319
+ },
320
+ dynamic_label_size=True,
321
+ # label_wrap_width=12,
322
+ label_over_points=True,
323
+ max_font_size=36,
324
+ min_font_size=4,
325
  )
326
+ if plot_type == "DataMapPlot"
327
+ else base_model.visualize_documents(
328
+ docs=all_docs,
329
+ reduced_embeddings=reduced_embeddings_array,
330
+ custom_labels=True,
331
+ title="",
332
+ )
333
+ )
334
+
335
+ dataset_clear_name = dataset.replace("/", "-")
336
+ plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
337
+ if plot_type == "DataMapPlot":
338
+ topic_plot.savefig(plot_png, format="png", dpi=300)
339
+ else:
340
+ topic_plot.write_image(plot_png)
341
+
342
+ custom_labels = base_model.custom_labels_
343
+ topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
344
+ yield (
345
+ gr.Accordion(open=False),
346
+ topics_info,
347
+ topic_plot,
348
+ gr.Label(
349
+ {
350
+ "βœ… " + message: 1.0,
351
+ f"βœ… Generating topic names with {model_id}": 1.0,
352
+ "⏳ Creating Interactive Space": 0.0,
353
+ },
354
+ visible=True,
355
+ ),
356
+ "",
357
+ )
358
+ interactive_plot = datamapplot.create_interactive_plot(
359
+ reduced_embeddings_array,
360
+ topic_names_array,
361
+ hover_text=all_docs,
362
+ title=dataset,
363
+ sub_title=sub_title.replace(
364
+ "dataset",
365
+ f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
366
+ ),
367
+ enable_search=True,
368
+ # TODO: Export data to .arrow and also serve it
369
+ inline_data=True,
370
+ # offline_data_prefix=dataset_clear_name,
371
+ initial_zoom_fraction=0.8,
372
+ )
373
+ html_content = str(interactive_plot)
374
+ html_file_path = f"{dataset_clear_name}.html"
375
+ with open(html_file_path, "w", encoding="utf-8") as html_file:
376
+ html_file.write(html_content)
377
+
378
+ repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
379
+
380
+ space_id = create_space_with_content(
381
+ api=api,
382
+ repo_id=repo_id,
383
+ dataset_id=dataset,
384
+ html_file_path=html_file_path,
385
+ plot_file_path=plot_png,
386
+ space_card=SPACE_REPO_CARD_CONTENT,
387
+ token=HF_TOKEN,
388
+ )
389
+
390
+ space_link = f"https://huggingface.co/spaces/{space_id}"
391
+ yield (
392
+ gr.Accordion(open=False),
393
+ topics_info,
394
+ topic_plot,
395
+ gr.Label(
396
+ {
397
+ "βœ… " + message: 1.0,
398
+ f"βœ… Generating topic names with {model_id}": 1.0,
399
+ "βœ… Creating Interactive Space": 1.0,
400
+ },
401
+ visible=True,
402
+ ),
403
+ f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
404
+ )
405
+ del reduce_umap_model, all_docs, reduced_embeddings_list
406
+ del (
407
+ base_model,
408
+ all_topics,
409
+ topics_info,
410
+ topic_plot,
411
+ topic_names_array,
412
+ interactive_plot,
413
+ )
414
+ cuda.empty_cache()
415
 
416
 
417
  with gr.Blocks() as demo:
 
460
  generate_button = gr.Button("Generate Topics", variant="primary")
461
 
462
  gr.Markdown("## Data map")
463
+ full_topics_generation_label = gr.Label(visible=False, show_label=False)
464
  open_space_label = gr.Markdown()
465
  topics_plot = gr.Plot()
466
+ with gr.Accordion("Topics Info", open=False):
467
+ topics_df = gr.DataFrame(interactive=False, visible=True)
468
  gr.HTML(
469
  f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
470
  )
 
486
  data_details_accordion,
487
  topics_df,
488
  topics_plot,
489
+ full_topics_generation_label,
490
  open_space_label,
491
  ],
492
  )