Spaces:
Sleeping
Sleeping
# import spaces | |
import gradio as gr | |
import logging | |
import os | |
import datamapplot | |
import numpy as np | |
from dotenv import load_dotenv | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
from bertopic import BERTopic | |
from bertopic.representation import KeyBERTInspired | |
from huggingface_hub import HfApi, InferenceClient | |
from sklearn.feature_extraction.text import CountVectorizer | |
from sentence_transformers import SentenceTransformer | |
from torch import cuda | |
from src.hub import create_space_with_content | |
from src.templates import LLAMA_3_8B_PROMPT, SPACE_REPO_CARD_CONTENT | |
from src.viewer_api import ( | |
get_split_rows, | |
get_parquet_urls, | |
get_docs_from_parquet, | |
get_info, | |
) | |
# Load environment variables
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

# Maximum number of rows modelled per request, and rows read/fitted per chunk.
MAX_ROWS = int(os.getenv("MAX_ROWS", "8_000"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
# Hub organization under which the generated interactive Spaces are created.
DATASETS_TOPICS_ORGANIZATION = os.getenv(
    "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
)
# Integer feature flags ("0" disables): use cuML's UMAP/HDBSCAN instead of
# umap-learn/hdbscan, and generate topic names with an LLM.
USE_CUML = int(os.getenv("USE_CUML", "1"))
USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))

# Use cuml lib only if configured
if USE_CUML:
    from cuml.manifold import UMAP
    from cuml.cluster import HDBSCAN
else:
    from umap import UMAP
    from hdbscan import HDBSCAN

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

api = HfApi(token=HF_TOKEN)
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

# Representation model
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
representation_model = KeyBERTInspired()
vectorizer_model = CountVectorizer(stop_words="english")
# Authenticate explicitly, like ``api`` above, rather than relying on the
# client picking the token up from the environment.
inference_client = InferenceClient(model_id, token=HF_TOKEN)
def calculate_embeddings(docs):
    """Encode ``docs`` with the module-level SentenceTransformer model.

    Documents are encoded in batches of 32 with a progress bar; the encoder's
    output is returned unchanged.
    """
    encoded = sentence_model.encode(docs, batch_size=32, show_progress_bar=True)
    return encoded
def calculate_n_neighbors_and_components(n_rows):
    """Derive UMAP hyper-parameters from the number of rows being modelled.

    ``n_neighbors`` is one twentieth of ``n_rows`` clamped to ``[15, 100]``;
    ``n_components`` is 10 when there are more than 1000 rows, otherwise 5.

    Returns:
        The tuple ``(n_neighbors, n_components)``.
    """
    neighbors = n_rows // 20
    if neighbors < 15:
        neighbors = 15
    elif neighbors > 100:
        neighbors = 100

    components = 5
    if n_rows > 1000:
        # Larger datasets get a higher-dimensional reduction.
        components = 10
    return neighbors, components
def fit_model(docs, embeddings, n_neighbors, n_components):
    """Fit a fresh BERTopic model on one chunk of documents.

    Args:
        docs: texts of the chunk.
        embeddings: sentence embeddings for ``docs``; passed to
            ``BERTopic.fit`` together with the texts.
        n_neighbors: UMAP neighbourhood size. It also sets the HDBSCAN
            minimum cluster size (half of it, at least 5) and BERTopic's
            ``min_topic_size``.
        n_components: number of dimensions UMAP reduces to before clustering.

    Returns:
        The fitted ``BERTopic`` instance.
    """
    # Half of n_neighbors (never below 5) keeps clusters small so fewer
    # documents end up as outliers.
    min_cluster_size = max(5, n_neighbors // 2)

    topic_model = BERTopic(
        language="english",
        # Sub-models, in pipeline order: embed, reduce, cluster, tokenize, label.
        embedding_model=sentence_model,
        umap_model=UMAP(
            n_neighbors=n_neighbors,
            n_components=n_components,
            min_dist=0.0,
            metric="cosine",
            random_state=42,
        ),
        hdbscan_model=HDBSCAN(
            min_cluster_size=min_cluster_size,
            metric="euclidean",
            cluster_selection_method="eom",
            prediction_data=True,
        ),
        vectorizer_model=vectorizer_model,
        representation_model=representation_model,
        # Hyperparameters
        top_n_words=10,
        verbose=True,
        # NOTE(review): BERTopic documents min_topic_size as unused when a
        # custom hdbscan_model is supplied — confirm it has any effect here.
        min_topic_size=n_neighbors,
    )

    logging.info("Fitting new model")
    topic_model.fit(docs, embeddings)
    logging.info("End fitting new model")
    return topic_model
def _no_data_outputs():
    """Build the Gradio outputs shown when the selected split has nothing to model."""
    return (
        gr.Accordion(open=True),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label(
            {"❌ Error: No data found for the selected dataset": 0.0}, visible=True
        ),
        "",
    )


def _build_topic_plot(
    model, docs, topics, reduced_embeddings, plot_type, sub_title, custom_labels=False
):
    """Render the document map of ``model``.

    Args:
        model: fitted (possibly merged) BERTopic model.
        docs: every modelled document, in processing order.
        topics: per-document topic assignments (``model.topics_``); BERTopic
            uses this argument to select which topics are shown.
        reduced_embeddings: 2D coordinates aligned with ``docs``.
        plot_type: ``"DataMapPlot"`` renders with ``visualize_document_datamap``;
            any other value renders a Plotly figure with ``visualize_documents``.
        sub_title: subtitle for the DataMapPlot (not used by the Plotly figure).
        custom_labels: draw the labels set via ``set_topic_labels``.

    Returns:
        The figure object returned by BERTopic.
    """
    if plot_type == "DataMapPlot":
        return model.visualize_document_datamap(
            docs=docs,
            topics=topics,
            custom_labels=custom_labels,
            reduced_embeddings=reduced_embeddings,
            title="",
            sub_title=sub_title,
            width=800,
            height=700,
            arrowprops={
                "arrowstyle": "wedge,tail_width=0.5",
                "connectionstyle": "arc3,rad=0.05",
                "linewidth": 0,
                "fc": "#33333377",
            },
            dynamic_label_size=True,
            # label_wrap_width=12,
            label_over_points=True,
            max_font_size=36,
            min_font_size=4,
        )
    return model.visualize_documents(
        docs=docs,
        topics=topics,
        reduced_embeddings=reduced_embeddings,
        custom_labels=custom_labels,
        title="",
    )


def _generate_topic_labels(topics_info):
    """Ask ``model_id`` (via ``inference_client``) for a name for every topic.

    Args:
        topics_info: DataFrame from ``BERTopic.get_topic_info()``; its ``Topic``
            and ``Representation`` (keyword list) columns are read.

    Returns:
        dict mapping topic id -> generated label. Topics whose request raises
        or returns empty content are logged and left out, so
        ``set_topic_labels`` keeps their existing label.
    """
    labels = {}
    for _, row in topics_info.iterrows():
        logging.info(
            f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
        )
        prompt = LLAMA_3_8B_PROMPT.replace(
            "[KEYWORDS]", ",".join(row["Representation"])
        )
        prompt_messages = [
            {
                "role": "system",
                "content": "You are a helpful, respectful and honest assistant for labeling topics.",
            },
            {"role": "user", "content": prompt},
        ]
        try:
            output = inference_client.chat_completion(
                messages=prompt_messages,
                stream=False,
                max_tokens=500,
                top_p=0.8,
                seed=42,
            )
            inference_response = output.choices[0].message.content
        except Exception:
            # A single failed remote call should not abort the whole run.
            logging.exception(f"Topic name generation failed for topic {row['Topic']}")
            continue
        logging.info("Inference response:")
        logging.info(inference_response)
        if not inference_response:
            continue
        labels[row["Topic"]] = inference_response.replace("Topic=", "").strip()
    return labels


def _linked_sub_title(sub_title, dataset, config, split):
    """Make the first word "dataset" of ``sub_title`` a link to the dataset viewer.

    The subtitle is meant to be rendered as HTML (a link tag is embedded in
    it), so the text is escaped first: it contains a column name taken from
    an arbitrary Hub dataset. Only the first occurrence is replaced, which is
    the one in the fixed subtitle template; the column name comes later.
    """
    import html

    viewer_url = html.escape(
        f"https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}"
    )
    return html.escape(sub_title, quote=False).replace(
        "dataset",
        f"<a href='{viewer_url}' target='_blank'>dataset</a>",
        1,
    )


# @spaces.GPU(duration=60 * 5)
def generate_topics(dataset, config, split, column, plot_type):
    """Model the topics of a text column of a Hub dataset split, streaming UI updates.

    The split is read in chunks of at most ``CHUNK_SIZE`` rows, up to
    ``MAX_ROWS`` rows in total. Each chunk is fitted with its own BERTopic
    model, which is merged into the running model with
    ``BERTopic.merge_models``. The topic table and document map are refreshed
    after every chunk. When ``USE_LLM_TEXT_GENERATION`` is set, topic names
    are then generated with ``model_id``. Finally, an interactive datamapplot
    page is published as a Space under ``DATASETS_TOPICS_ORGANIZATION``.

    Args:
        dataset: Hub dataset id.
        config: subset/config name.
        split: split name.
        column: name of the text column to model.
        plot_type: ``"DataMapPlot"`` or ``"Plotly"``.

    Yields:
        5-tuples matching the Gradio outputs: (data details accordion, topics
        info table, plot, status label, markdown link to the created Space).
    """
    logging.info(
        f"Generating topics for {dataset=} {config=} {split=} {column=} {plot_type=}"
    )
    parquet_urls = get_parquet_urls(dataset, config, split)
    split_rows = get_split_rows(dataset, config, split)
    if split_rows is None or split_rows == 0:
        # This is a generator: a ``return <value>`` never reaches Gradio, so
        # the error state must be yielded before stopping.
        yield _no_data_outputs()
        return

    logging.info(f"Split number of rows: {split_rows}")
    limit = min(split_rows, MAX_ROWS)
    n_neighbors, n_components = calculate_n_neighbors_and_components(limit)
    # NOTE(review): every chunk is fitted with hyper-parameters derived from
    # ``limit``, not from the chunk's own size; a very small last chunk may
    # hold fewer rows than ``n_neighbors`` — confirm UMAP/HDBSCAN cope.
    reduce_umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=2,  # For visualization, keeping it for 2D
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    offset = 0
    rows_processed = 0
    base_model = None
    all_docs = []
    reduced_embeddings_list = []
    topics_info, topic_plot = None, None
    full_processing = split_rows <= MAX_ROWS
    message = (
        f"Processing topics for full dataset: 0 of ({split_rows} rows)"
        if full_processing
        else f"Processing topics for partial dataset 0 of ({limit} rows)"
    )
    sub_title = (
        f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
        if full_processing
        else f"Data map for a sample of the dataset (first {limit} rows) using the column '{column}'"
    )
    yield (
        gr.Accordion(open=False),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({"⏳ " + message: 0.0}, visible=True),
        "",
    )

    while offset < limit:
        # Never read past ``limit``: when MAX_ROWS is not a multiple of
        # CHUNK_SIZE a full-size last chunk would overshoot the row budget.
        chunk_size = min(CHUNK_SIZE, limit - offset)
        logging.info(f"----> Getting records from {offset=} with {chunk_size=}")
        docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
        if not docs:
            break
        logging.info(f"Got {len(docs)} docs ✓")
        embeddings = calculate_embeddings(docs)
        new_model = fit_model(docs, embeddings, n_neighbors, n_components)

        if base_model is None:
            base_model = new_model
            logging.info(
                f"The following topics are newly found: {base_model.topic_labels_}"
            )
        else:
            updated_model = BERTopic.merge_models([base_model, new_model])
            nr_new_topics = len(set(updated_model.topics_)) - len(
                set(base_model.topics_)
            )
            # Guard the slice: ``[-0:]`` would select every label, not none.
            new_topics = (
                list(updated_model.topic_labels_.values())[-nr_new_topics:]
                if nr_new_topics > 0
                else []
            )
            logging.info(f"The following topics are newly found: {new_topics}")
            base_model = updated_model

        logging.info("Reducing embeddings to 2D")
        reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
        reduced_embeddings_list.append(reduced_embeddings)
        logging.info("Reducing embeddings to 2D ✓")

        all_docs.extend(docs)
        reduced_embeddings_array = np.vstack(reduced_embeddings_list)
        topics_info = base_model.get_topic_info()
        all_topics = base_model.topics_
        logging.info(f"Preparing topics {plot_type} plot")
        topic_plot = _build_topic_plot(
            base_model,
            all_docs,
            all_topics,
            reduced_embeddings_array,
            plot_type,
            sub_title,
        )
        logging.info("Plot done ✓")

        rows_processed += len(docs)
        progress = min(rows_processed / limit, 1.0)
        logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
        message = (
            f"Processing topics for full dataset: {rows_processed} of {limit}"
            if full_processing
            else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
        )
        yield (
            gr.Accordion(open=False),
            topics_info,
            topic_plot,
            gr.Label({"⏳ " + message: progress}, visible=True),
            "",
        )
        offset += CHUNK_SIZE
        del docs, embeddings, new_model, reduced_embeddings

    logging.info("Finished processing all data")
    if base_model is None:
        # The very first chunk was empty: there is nothing to label or plot.
        yield _no_data_outputs()
        return

    all_topics = base_model.topics_
    completed_steps = {"✅ " + message: 1.0}

    if USE_LLM_TEXT_GENERATION:
        yield (
            gr.Accordion(open=False),
            topics_info,
            topic_plot,
            gr.Label(
                {
                    **completed_steps,
                    f"⏳ Generating topic names with {model_id}": 0.0,
                },
                visible=True,
            ),
            "",
        )
        base_model.set_topic_labels(_generate_topic_labels(topics_info))
        completed_steps[f"✅ Generating topic names with {model_id}"] = 1.0
        topics_info = base_model.get_topic_info()
        topic_plot = _build_topic_plot(
            base_model,
            all_docs,
            all_topics,
            reduced_embeddings_array,
            plot_type,
            sub_title,
            custom_labels=True,
        )

    dataset_clear_name = dataset.replace("/", "-")
    plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
    if plot_type == "DataMapPlot":
        topic_plot.savefig(plot_png, format="png", dpi=300)
    else:
        topic_plot.write_image(plot_png)

    # Look labels up by topic id. Indexing ``custom_labels_`` with
    # ``topic + 1`` is only right when the outlier topic -1 exists.
    label_column = "CustomName" if "CustomName" in topics_info.columns else "Name"
    label_by_topic = dict(zip(topics_info["Topic"], topics_info[label_column]))
    topic_names_array = [label_by_topic[doc_topic] for doc_topic in all_topics]

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {**completed_steps, "⏳ Creating Interactive Space": 0.0},
            visible=True,
        ),
        "",
    )

    interactive_plot = datamapplot.create_interactive_plot(
        reduced_embeddings_array,
        topic_names_array,
        hover_text=all_docs,
        title=dataset,
        sub_title=_linked_sub_title(sub_title, dataset, config, split),
        enable_search=True,
        # TODO: Export data to .arrow and also serve it
        inline_data=True,
        # offline_data_prefix=dataset_clear_name,
        initial_zoom_fraction=0.9,
        cluster_boundary_polygons=True,
    )
    html_content = str(interactive_plot)
    html_file_path = f"{dataset_clear_name}.html"
    with open(html_file_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_content)

    repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
    space_id = create_space_with_content(
        api=api,
        repo_id=repo_id,
        dataset_id=dataset,
        html_file_path=html_file_path,
        plot_file_path=plot_png,
        space_card=SPACE_REPO_CARD_CONTENT,
        token=HF_TOKEN,
    )
    space_link = f"https://huggingface.co/spaces/{space_id}"

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {**completed_steps, "✅ Creating Interactive Space": 1.0},
            visible=True,
        ),
        f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
    )

    # Drop references to the large objects before releasing cached GPU memory.
    del reduce_umap_model, all_docs, reduced_embeddings_list
    del (
        base_model,
        all_topics,
        topics_info,
        topic_plot,
        topic_names_array,
        interactive_plot,
    )
    cuda.empty_cache()
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>💠 Dataset Topic Discovery 🔭</h1>")
    gr.HTML(
        "<h3 style='text-align: center;'>Select a dataset and text column for topic modeling</h3>"
    )
    gr.HTML(
        "<p style='text-align: center; color:orange;'>⚠ This space is in progress, and we're actively working on it, so you might find some bugs! Please report any issues you have in the Community tab to help us make it better for all.</p>"
    )

    data_details_accordion = gr.Accordion("Data details", open=True)
    with data_details_accordion:
        with gr.Row():
            with gr.Column(scale=3):
                dataset_name = HuggingfaceHubSearch(
                    label="Hub Dataset ID",
                    placeholder="Search for dataset id on Huggingface",
                    search_type="dataset",
                )
                subset_dropdown = gr.Dropdown(label="Subset", visible=False)
                split_dropdown = gr.Dropdown(label="Split", visible=False)

        with gr.Accordion("Dataset preview", open=False):
            dataset_preview = gr.HTML()

            def embed(name, subset, split):
                """Return an HTML update embedding the Hub dataset viewer for the selection."""
                if not name or not subset or not split:
                    # Nothing fully selected yet: avoid rendering a broken iframe.
                    return gr.HTML(value="")
                html_code = f"""
                <iframe
                    src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
                    frameborder="0"
                    width="100%"
                    height="600px"
                ></iframe>
                """
                return gr.HTML(value=html_code)

        with gr.Row():
            text_column_dropdown = gr.Dropdown(label="Text column name")
            plot_type_radio = gr.Radio(
                ["DataMapPlot", "Plotly"],
                value="DataMapPlot",
                label="Choose the plot type",
                interactive=True,
            )
        generate_button = gr.Button("Generate Topics", variant="primary")

    gr.Markdown("## Data map")
    full_topics_generation_label = gr.Label(visible=False, show_label=False)
    open_space_label = gr.Markdown()
    topics_plot = gr.Plot()
    with gr.Accordion("Topics Info", open=False):
        topics_df = gr.DataFrame(interactive=False, visible=True)
    gr.HTML(
        f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
    )
    gr.Markdown(
        "_Powered by [bertopic](https://maartengr.github.io/BERTopic/index.html) [datamapplot](https://datamapplot.readthedocs.io/en/latest/) and [duckdb](https://duckdb.org/)_"
    )

    generate_button.click(
        generate_topics,
        inputs=[
            dataset_name,
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
            plot_type_radio,
        ],
        outputs=[
            data_details_accordion,
            topics_df,
            topics_plot,
            full_topics_generation_label,
            open_space_label,
        ],
    )

    def _resolve_dataset_selection(
        dataset: str, default_subset: str, default_split: str, text_feature
    ) -> dict:
        """Compute subset/split/text-column dropdown updates for a dataset selection.

        Args:
            dataset: Hub dataset id typed or picked in the search box.
            default_subset: subset to keep selected if the dataset has it.
            default_split: split to keep selected if the subset has it.
            text_feature: currently selected text column. It is accepted for
                the text-column handler's sake but does not change the result.

        Returns:
            A dict keyed by the three dropdown components, as accepted by
            Gradio handlers whose outputs are those components.
        """

        def _hidden() -> dict:
            # Shown while no valid dataset is selected or its info is unavailable.
            return {
                subset_dropdown: gr.Dropdown(visible=False),
                split_dropdown: gr.Dropdown(visible=False),
                text_column_dropdown: gr.Dropdown(label="Text column name"),
            }

        if not dataset or "/" not in dataset.strip().strip("/"):
            return _hidden()
        try:
            info_resp = get_info(dataset)
        except Exception:
            # Best effort: keep the UI usable, but leave a trace of the failure.
            logging.exception(f"Could not fetch dataset info for {dataset=}")
            return _hidden()
        if not info_resp:
            return _hidden()

        subsets: list[str] = list(info_resp)
        subset = default_subset if default_subset in subsets else subsets[0]
        splits: list[str] = list(info_resp[subset]["splits"])
        split = default_split if default_split in splits else splits[0]
        features = info_resp[subset]["features"]

        def _is_string_feature(feature):
            return isinstance(feature, dict) and feature.get("dtype") == "string"

        text_features = [
            feature_name
            for feature_name, feature in features.items()
            if _is_string_feature(feature)
        ]
        return {
            subset_dropdown: gr.Dropdown(
                value=subset, choices=subsets, visible=len(subsets) > 1
            ),
            split_dropdown: gr.Dropdown(
                value=split, choices=splits, visible=len(splits) > 1
            ),
            text_column_dropdown: gr.Dropdown(
                choices=text_features, label="Text column name"
            ),
        }

    def show_input_from_dataset_search(dataset: str) -> dict:
        """Populate the dropdowns for a newly selected dataset."""
        return _resolve_dataset_selection(
            dataset, default_subset="default", default_split="train", text_feature=None
        )

    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        """Refresh splits and text columns after the subset changes."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split="train", text_feature=None
        )

    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        """Refresh the dropdowns after the split changes."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split=split, text_feature=None
        )

    def show_input_from_text_column_dropdown(
        dataset: str, subset: str, split: str, text_column
    ) -> dict:
        """Recompute the dropdowns for the current selection and text column.

        Not registered as a listener: the result does not depend on the text
        column, so wiring it would only repeat the ``get_info`` request.
        """
        return _resolve_dataset_selection(
            dataset,
            default_subset=subset,
            default_split=split,
            text_feature=text_column,
        )

    # Event wiring: the resolvers above only take effect once registered.
    # Each one refreshes the dataset preview after its dropdown updates land.
    dropdown_outputs = [subset_dropdown, split_dropdown, text_column_dropdown]
    preview_inputs = [dataset_name, subset_dropdown, split_dropdown]
    dataset_name.change(
        show_input_from_dataset_search,
        inputs=[dataset_name],
        outputs=dropdown_outputs,
    ).then(embed, inputs=preview_inputs, outputs=dataset_preview)
    subset_dropdown.change(
        show_input_from_subset_dropdown,
        inputs=[dataset_name, subset_dropdown],
        outputs=dropdown_outputs,
    ).then(embed, inputs=preview_inputs, outputs=dataset_preview)
    split_dropdown.change(
        show_input_from_split_dropdown,
        inputs=[dataset_name, subset_dropdown, split_dropdown],
        outputs=dropdown_outputs,
    ).then(embed, inputs=preview_inputs, outputs=dataset_preview)

demo.launch()