asoria HF staff committed on
Commit 657db0b
1 Parent(s): 9ccf916

Batched fit

Files changed (2)
  1. app.py +301 -4
  2. requirements.txt +7 -0
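
The commit replaces the "Hello" stub with a topic-discovery app that fits BERTopic in batches. The core pattern, sketched below on a plain list-of-chunks input (function and variable names here are illustrative, not taken from the commit), is: fit a base model on the first chunk, then fit a fresh model on each later chunk and fold it into the running model with BERTopic.merge_models:

from bertopic import BERTopic

def batched_fit(doc_chunks, min_topic_size=15):
    # Illustrative sketch: fit on the first chunk, then merge in the rest.
    chunks = iter(doc_chunks)
    base_model = BERTopic(min_topic_size=min_topic_size).fit(next(chunks))
    for chunk in chunks:
        new_model = BERTopic(min_topic_size=min_topic_size).fit(chunk)
        # merge_models keeps base_model's topics and appends any topic that
        # only the new model discovered.
        base_model = BERTopic.merge_models([base_model, new_model])
    return base_model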
app.py CHANGED
@@ -1,7 +1,304 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import requests
+ import logging
+ import duckdb
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
+ from bertopic import BERTopic
+ import pandas as pd
  import gradio as gr
+ from bertopic.representation import KeyBERTInspired
+
+
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+
+
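+ # One shared HTTP session so repeated datasets-server calls reuse connections.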
+ session = requests.Session()
+
+
+ def get_parquet_urls(dataset, config, split):
+     parquet_files = session.get(
+         f"https://datasets-server.huggingface.co/parquet?dataset={dataset}&config={config}&split={split}",
+         timeout=20,
+     ).json()
+     if "error" in parquet_files:
+         raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
+     parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
+     logging.info(f"Parquet files: {parquet_urls}")
+     return ",".join(f"'{url}'" for url in parquet_urls)
+
+
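+ # DuckDB reads the remote Parquet files over HTTP, so only the requested
+ # LIMIT/OFFSET window of rows is materialized in memory.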
+ def get_docs_from_parquet(parquet_urls, column, offset, limit):
+     SQL_QUERY = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
+     df = duckdb.sql(SQL_QUERY).to_df()
+     logging.debug(f"Dataframe: {df.head(5)}")
+     return df[column].tolist()
+
+
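+ # Batched fit: train a base BERTopic model on the first chunk, then fit a new
+ # model on each following chunk and merge it in, yielding the intermediate
+ # topic table and plot to the UI after every chunk.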
+ def generate_topics(dataset, config, split, column, nested_column, progress=gr.Progress()):
+     # gradio injects the progress tracker through the default parameter value.
+     logging.info(
+         f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
+     )
+
+     parquet_urls = get_parquet_urls(dataset, config, split)
+     limit = 1_000
+     chunk_size = 300
+     offset = 0
+     representation_model = KeyBERTInspired()
+
+     docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
+
+     base_model = BERTopic(
+         representation_model=representation_model, min_topic_size=15
+     ).fit(docs)
+
+     yield base_model.get_topic_info(), base_model.visualize_topics()
+
+     while True:
+         offset = offset + chunk_size
+         if offset >= limit:
+             break
+
+         docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
+         if not docs:  # dataset exhausted before reaching the row limit
+             break
+         logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
+         logging.info(docs[:5])
+
+         new_model = BERTopic(
+             "english", representation_model=representation_model, min_topic_size=15
+         ).fit(docs)
+         updated_model = BERTopic.merge_models([base_model, new_model])
+         nr_new_topics = len(set(updated_model.topics_)) - len(set(base_model.topics_))
+         new_topics = (
+             list(updated_model.topic_labels_.values())[-nr_new_topics:]
+             if nr_new_topics > 0
+             else []
+         )
+         logging.info("The following topics are newly found:")
+         logging.info(f"{new_topics}\n")
+
+         # Update the base model
+         base_model = updated_model
+
+         logging.info(base_model.get_topic_info())
+         yield base_model.get_topic_info(), base_model.visualize_topics()
+
+     return base_model.get_topic_info(), base_model.visualize_topics()
+
+
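+ # Gradio UI: pick a dataset, subset, split and text column, then stream the
+ # topic table and topic map back to the page as each chunk is processed.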
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # 💠 Dataset Topic Discovery 🔭
+         ## Select dataset and text column
+         """
+     )
+     with gr.Row():
+         with gr.Column(scale=3):
+             dataset_name = HuggingfaceHubSearch(
+                 label="Hub Dataset ID",
+                 placeholder="Search for dataset id on Huggingface",
+                 search_type="dataset",
+             )
+             subset_dropdown = gr.Dropdown(label="Subset", visible=False)
+             split_dropdown = gr.Dropdown(label="Split", visible=False)
+
+             with gr.Accordion("Dataset preview", open=False):
+
+                 @gr.render(inputs=[dataset_name, subset_dropdown, split_dropdown])
+                 def embed(name, subset, split):
+                     html_code = f"""
+                     <iframe
+                         src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
+                         frameborder="0"
+                         width="100%"
+                         height="600px"
+                     ></iframe>
+                     """
+                     return gr.HTML(value=html_code)
+
+     with gr.Row():
+         text_column_dropdown = gr.Dropdown(label="Text column name")
+         nested_text_column_dropdown = gr.Dropdown(
+             label="Nested text column name", visible=False
+         )
+
+     generate_button = gr.Button("Generate Notebook", variant="primary")
+
+     gr.Markdown("## Topics info")
+     topics_df = gr.DataFrame(interactive=False, visible=True)
+     topics_plot = gr.Plot()
+     generate_button.click(
+         generate_topics,
+         inputs=[
+             dataset_name,
+             subset_dropdown,
+             split_dropdown,
+             text_column_dropdown,
+             nested_text_column_dropdown,
+         ],
+         outputs=[topics_df, topics_plot],
+     )
+
+     # TODO: choose num_rows, random, or offset -> by default cap at 1176 rows
+     # (per the article, a GPU can process roughly 1176 docs/sec)
+
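+     # Query the datasets-server /info endpoint and rebuild the subset/split/
+     # column dropdowns; dropdowns with a single valid choice stay hidden.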
+     def _resolve_dataset_selection(
+         dataset: str, default_subset: str, default_split: str, text_feature
+     ):
+         if "/" not in dataset.strip().strip("/"):
+             return {
+                 subset_dropdown: gr.Dropdown(visible=False),
+                 split_dropdown: gr.Dropdown(visible=False),
+                 text_column_dropdown: gr.Dropdown(label="Text column name"),
+                 nested_text_column_dropdown: gr.Dropdown(visible=False),
+             }
+         info_resp = session.get(
+             f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=20
+         ).json()
+         if "error" in info_resp:
+             return {
+                 subset_dropdown: gr.Dropdown(visible=False),
+                 split_dropdown: gr.Dropdown(visible=False),
+                 text_column_dropdown: gr.Dropdown(label="Text column name"),
+                 nested_text_column_dropdown: gr.Dropdown(visible=False),
+             }
+         subsets: list[str] = list(info_resp["dataset_info"])
+         subset = default_subset if default_subset in subsets else subsets[0]
+         splits: list[str] = list(info_resp["dataset_info"][subset]["splits"])
+         split = default_split if default_split in splits else splits[0]
+         features = info_resp["dataset_info"][subset]["features"]
+
+         def _is_string_feature(feature):
+             return isinstance(feature, dict) and feature.get("dtype") == "string"
+
+         text_features = [
+             feature_name
+             for feature_name, feature in features.items()
+             if _is_string_feature(feature)
+         ]
+         nested_features = [
+             feature_name
+             for feature_name, feature in features.items()
+             if isinstance(feature, dict)
+             and isinstance(next(iter(feature.values())), dict)
+         ]
+         nested_text_features = [
+             feature_name
+             for feature_name in nested_features
+             if any(
+                 _is_string_feature(nested_feature)
+                 for nested_feature in features[feature_name].values()
+             )
+         ]
+         if not text_feature:
+             return {
+                 subset_dropdown: gr.Dropdown(
+                     value=subset, choices=subsets, visible=len(subsets) > 1
+                 ),
+                 split_dropdown: gr.Dropdown(
+                     value=split, choices=splits, visible=len(splits) > 1
+                 ),
+                 text_column_dropdown: gr.Dropdown(
+                     choices=text_features + nested_text_features,
+                     label="Text column name",
+                 ),
+                 nested_text_column_dropdown: gr.Dropdown(visible=False),
+             }
+         if text_feature in nested_text_features:
+             nested_keys = [
+                 feature_name
+                 for feature_name, feature in features[text_feature].items()
+                 if _is_string_feature(feature)
+             ]
+             return {
+                 subset_dropdown: gr.Dropdown(
+                     value=subset, choices=subsets, visible=len(subsets) > 1
+                 ),
+                 split_dropdown: gr.Dropdown(
+                     value=split, choices=splits, visible=len(splits) > 1
+                 ),
+                 text_column_dropdown: gr.Dropdown(
+                     choices=text_features + nested_text_features,
+                     label="Text column name",
+                 ),
+                 nested_text_column_dropdown: gr.Dropdown(
+                     value=nested_keys[0],
+                     choices=nested_keys,
+                     label="Nested text column name",
+                     visible=True,
+                 ),
+             }
+         return {
+             subset_dropdown: gr.Dropdown(
+                 value=subset, choices=subsets, visible=len(subsets) > 1
+             ),
+             split_dropdown: gr.Dropdown(
+                 value=split, choices=splits, visible=len(splits) > 1
+             ),
+             text_column_dropdown: gr.Dropdown(
+                 choices=text_features + nested_text_features, label="Text column name"
+             ),
+             nested_text_column_dropdown: gr.Dropdown(visible=False),
+         }
+
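+     # Every dropdown change re-resolves the selection so the four inputs stay
+     # consistent with one another.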
+     @dataset_name.change(
+         inputs=[dataset_name],
+         outputs=[
+             subset_dropdown,
+             split_dropdown,
+             text_column_dropdown,
+             nested_text_column_dropdown,
+         ],
+     )
+     def show_input_from_dataset_name(dataset: str) -> dict:
+         return _resolve_dataset_selection(
+             dataset, default_subset="default", default_split="train", text_feature=None
+         )
+
+     @subset_dropdown.change(
+         inputs=[dataset_name, subset_dropdown],
+         outputs=[
+             subset_dropdown,
+             split_dropdown,
+             text_column_dropdown,
+             nested_text_column_dropdown,
+         ],
+     )
+     def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
+         return _resolve_dataset_selection(
+             dataset, default_subset=subset, default_split="train", text_feature=None
+         )
+
+     @split_dropdown.change(
+         inputs=[dataset_name, subset_dropdown, split_dropdown],
+         outputs=[
+             subset_dropdown,
+             split_dropdown,
+             text_column_dropdown,
+             nested_text_column_dropdown,
+         ],
+     )
+     def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
+         return _resolve_dataset_selection(
+             dataset, default_subset=subset, default_split=split, text_feature=None
+         )
+
+     @text_column_dropdown.change(
+         inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown],
+         outputs=[
+             subset_dropdown,
+             split_dropdown,
+             text_column_dropdown,
+             nested_text_column_dropdown,
+         ],
+     )
+     def show_input_from_text_column_dropdown(
+         dataset: str, subset: str, split: str, text_column
+     ) -> dict:
+         return _resolve_dataset_selection(
+             dataset,
+             default_subset=subset,
+             default_split=split,
+             text_feature=text_column,
+         )
+
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio_huggingfacehub_search==0.0.7
+ duckdb
+ umap-learn
+ sentence-transformers
+ datamapplot
+ bertopic
+ pandas
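
To try the Space locally, something like the following should work (gradio itself is not pinned in requirements.txt because the Spaces runtime provides it, so it is added here explicitly):

pip install -r requirements.txt gradio
python app.py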