asoria HF staff commited on
Commit
24bed82
1 Parent(s): bf92466

Adding logs

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -15,6 +15,7 @@ from bertopic.representation import KeyBERTInspired
15
  from huggingface_hub import HfApi, InferenceClient
16
  from sklearn.feature_extraction.text import CountVectorizer
17
  from sentence_transformers import SentenceTransformer
 
18
 
19
  from src.hub import create_space_with_content
20
  from src.templates import LLAMA_3_8B_PROMPT, SPACE_REPO_CARD_CONTENT
@@ -167,14 +168,11 @@ def generate_topics(dataset, config, split, column, plot_type):
167
 
168
  try:
169
  while offset < limit:
 
170
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
171
  if not docs:
172
  break
173
-
174
- logging.info(
175
- f"----> Processing chunk: {offset=} {CHUNK_SIZE=} with {len(docs)} docs"
176
- )
177
-
178
  embeddings = calculate_embeddings(docs)
179
  new_model = fit_model(docs, embeddings, n_neighbors, n_components)
180
 
@@ -192,14 +190,18 @@ def generate_topics(dataset, config, split, column, plot_type):
192
  logging.info(f"The following topics are newly found: {new_topics}")
193
  base_model = updated_model
194
 
 
195
  reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
196
  reduced_embeddings_list.append(reduced_embeddings)
197
 
198
  all_docs.extend(docs)
199
  reduced_embeddings_array = np.vstack(reduced_embeddings_list)
 
200
 
201
  topics_info = base_model.get_topic_info()
202
  all_topics = base_model.topics_
 
 
203
  topic_plot = (
204
  base_model.visualize_document_datamap(
205
  docs=all_docs,
@@ -224,11 +226,13 @@ def generate_topics(dataset, config, split, column, plot_type):
224
  if plot_type == "DataMapPlot"
225
  else base_model.visualize_documents(
226
  docs=all_docs,
 
227
  reduced_embeddings=reduced_embeddings_array,
228
  custom_labels=True,
229
  title="",
230
  )
231
  )
 
232
  rows_processed += len(docs)
233
  progress = min(rows_processed / limit, 1.0)
234
  logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
@@ -403,7 +407,7 @@ def generate_topics(dataset, config, split, column, plot_type):
403
  del (
404
  base_model,
405
  all_topics,
406
- topic_info,
407
  topic_names_array,
408
  interactive_plot,
409
  )
 
15
  from huggingface_hub import HfApi, InferenceClient
16
  from sklearn.feature_extraction.text import CountVectorizer
17
  from sentence_transformers import SentenceTransformer
18
+ from torch import cuda
19
 
20
  from src.hub import create_space_with_content
21
  from src.templates import LLAMA_3_8B_PROMPT, SPACE_REPO_CARD_CONTENT
 
168
 
169
  try:
170
  while offset < limit:
171
+ logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
172
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
173
  if not docs:
174
  break
175
+ logging.info(f"Got {len(docs)} docs ✓")
 
 
 
 
176
  embeddings = calculate_embeddings(docs)
177
  new_model = fit_model(docs, embeddings, n_neighbors, n_components)
178
 
 
190
  logging.info(f"The following topics are newly found: {new_topics}")
191
  base_model = updated_model
192
 
193
+ logging.info("Reducing embeddings to 2D")
194
  reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
195
  reduced_embeddings_list.append(reduced_embeddings)
196
 
197
  all_docs.extend(docs)
198
  reduced_embeddings_array = np.vstack(reduced_embeddings_list)
199
+ logging.info("Reducing embeddings to 2D ✓")
200
 
201
  topics_info = base_model.get_topic_info()
202
  all_topics = base_model.topics_
203
+ logging.info(f"Preparing topics {plot_type} plot")
204
+
205
  topic_plot = (
206
  base_model.visualize_document_datamap(
207
  docs=all_docs,
 
226
  if plot_type == "DataMapPlot"
227
  else base_model.visualize_documents(
228
  docs=all_docs,
229
+ topics=all_topics,
230
  reduced_embeddings=reduced_embeddings_array,
231
  custom_labels=True,
232
  title="",
233
  )
234
  )
235
+ logging.info("Plot done ✓")
236
  rows_processed += len(docs)
237
  progress = min(rows_processed / limit, 1.0)
238
  logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
 
407
  del (
408
  base_model,
409
  all_topics,
410
+ topics_info,
411
  topic_names_array,
412
  interactive_plot,
413
  )