asoria HF staff commited on
Commit
fe421d1
1 Parent(s): e2d9a99

Adding Llama2 topics

Browse files
Files changed (2) hide show
  1. app.py +127 -51
  2. prompts.py +29 -0
app.py CHANGED
@@ -6,9 +6,24 @@ from gradio_huggingfacehub_search import HuggingfaceHubSearch
6
  from bertopic import BERTopic
7
  import pandas as pd
8
  import gradio as gr
9
- from bertopic.representation import KeyBERTInspired
 
 
 
 
10
  from umap import UMAP
11
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # from cuml.cluster import HDBSCAN
14
  # from cuml.manifold import UMAP
@@ -21,6 +36,60 @@ logging.basicConfig(
21
 
22
  session = requests.Session()
23
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def get_parquet_urls(dataset, config, split):
@@ -44,22 +113,27 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
44
 
45
  # @spaces.GPU
46
  def calculate_embeddings(docs):
47
- embeddings = sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
48
- logging.info(f"Embeddings shape: {embeddings.shape}")
49
- return embeddings
50
 
51
 
52
  # @spaces.GPU
53
- def fit_model(base_model, sentence_model, representation_model, docs, embeddings):
54
  new_model = BERTopic(
55
  "english",
 
56
  embedding_model=sentence_model,
 
 
57
  representation_model=representation_model,
58
- min_topic_size=15, # umap_model=umap_model, hdbscan_model=hdbscan_model
 
 
 
59
  )
60
  logging.info("Fitting new model")
61
  new_model.fit(docs, embeddings)
62
  logging.info("End fitting new model")
 
63
  if base_model is None:
64
  return new_model, new_model
65
 
@@ -68,6 +142,8 @@ def fit_model(base_model, sentence_model, representation_model, docs, embeddings
68
  new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
69
  logging.info("The following topics are newly found:")
70
  logging.info(f"{new_topics}\n")
 
 
71
  return updated_model, new_model
72
 
73
 
@@ -80,7 +156,6 @@ def generate_topics(dataset, config, split, column, nested_column):
80
  limit = 1_000
81
  chunk_size = 300
82
  offset = 0
83
- representation_model = KeyBERTInspired()
84
  base_model = None
85
  all_docs = []
86
  all_reduced_embeddings = np.empty((0, 2))
@@ -93,22 +168,25 @@ def generate_topics(dataset, config, split, column, nested_column):
93
  offset = offset + chunk_size
94
  if not docs or offset >= limit:
95
  break
96
- base_model, _ = fit_model(
97
- base_model, sentence_model, representation_model, docs, embeddings
98
- )
 
 
 
 
 
99
  reduced_embeddings = UMAP(
100
  n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
101
  ).fit_transform(embeddings)
102
- logging.info(f"Reduced embeddings shape: {reduced_embeddings.shape}")
103
 
104
  all_docs.extend(docs)
105
  all_reduced_embeddings = np.vstack((all_reduced_embeddings, reduced_embeddings))
106
- logging.info(f"Stacked embeddings shape: {all_reduced_embeddings.shape}")
107
  topics_info = base_model.get_topic_info()
108
  topic_plot = base_model.visualize_documents(
109
- all_docs, reduced_embeddings=all_reduced_embeddings
110
  )
111
-
112
  yield topics_info, topic_plot
113
 
114
  logging.info("Finished processing all data")
@@ -116,47 +194,45 @@ def generate_topics(dataset, config, split, column, nested_column):
116
 
117
 
118
  with gr.Blocks() as demo:
119
- gr.Markdown(
120
- """
121
- # 💠 Dataset Topic Discovery 🔭
122
- ## Select dataset and text column
123
- """
124
- )
125
- with gr.Row():
126
- with gr.Column(scale=3):
127
- dataset_name = HuggingfaceHubSearch(
128
- label="Hub Dataset ID",
129
- placeholder="Search for dataset id on Huggingface",
130
- search_type="dataset",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  )
132
- subset_dropdown = gr.Dropdown(label="Subset", visible=False)
133
- split_dropdown = gr.Dropdown(label="Split", visible=False)
134
-
135
- with gr.Accordion("Dataset preview", open=False):
136
-
137
- @gr.render(inputs=[dataset_name, subset_dropdown, split_dropdown])
138
- def embed(name, subset, split):
139
- html_code = f"""
140
- <iframe
141
- src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
142
- frameborder="0"
143
- width="100%"
144
- height="600px"
145
- ></iframe>
146
- """
147
- return gr.HTML(value=html_code)
148
-
149
- with gr.Row():
150
- text_column_dropdown = gr.Dropdown(label="Text column name")
151
- nested_text_column_dropdown = gr.Dropdown(
152
- label="Nested text column name", visible=False
153
- )
154
 
155
- generate_button = gr.Button("Generate Notebook", variant="primary")
156
 
157
- gr.Markdown("## Topics info")
158
- topics_df = gr.DataFrame(interactive=False, visible=True)
159
  topics_plot = gr.Plot()
 
 
160
  generate_button.click(
161
  generate_topics,
162
  inputs=[
 
6
  from bertopic import BERTopic
7
  import pandas as pd
8
  import gradio as gr
9
+ from bertopic.representation import (
10
+ KeyBERTInspired,
11
+ MaximalMarginalRelevance,
12
+ TextGeneration,
13
+ )
14
  from umap import UMAP
15
  import numpy as np
16
+ from torch import cuda
17
+ from torch import bfloat16
18
+ from transformers import (
19
+ BitsAndBytesConfig,
20
+ AutoTokenizer,
21
+ AutoModelForCausalLM,
22
+ pipeline,
23
+ )
24
+ from prompts import system_prompt, example_prompt, main_prompt
25
+ from umap import UMAP
26
+ from hdbscan import HDBSCAN
27
 
28
  # from cuml.cluster import HDBSCAN
29
  # from cuml.manifold import UMAP
 
36
 
37
  session = requests.Session()
38
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
39
+ keybert = KeyBERTInspired()
40
+ mmr = MaximalMarginalRelevance(diversity=0.3)
41
+
42
+
43
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
44
+ device = f"cuda:{cuda.current_device()}" if cuda.is_available() else "cpu"
45
+ logging.info(device)
46
+
47
+ bnb_config = BitsAndBytesConfig(
48
+ load_in_4bit=True, # 4-bit quantization
49
+ bnb_4bit_quant_type="nf4", # Normalized float 4
50
+ bnb_4bit_use_double_quant=True, # Second quantization after the first
51
+ bnb_4bit_compute_dtype=bfloat16, # Computation type
52
+ )
53
+
54
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
55
+
56
+ # Llama 2 Model
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ model_id,
59
+ trust_remote_code=True,
60
+ quantization_config=bnb_config,
61
+ device_map="auto",
62
+ )
63
+
64
+ generator = pipeline(
65
+ model=model,
66
+ tokenizer=tokenizer,
67
+ task="text-generation",
68
+ temperature=0.1,
69
+ max_new_tokens=500,
70
+ repetition_penalty=1.1,
71
+ )
72
+ prompt = system_prompt + example_prompt + main_prompt
73
+
74
+ llama2 = TextGeneration(generator, prompt=prompt)
75
+ representation_model = {
76
+ # "KeyBERT": keybert,
77
+ "Llama2": llama2,
78
+ # "MMR": mmr,
79
+ }
80
+
81
+ # umap_model = UMAP(
82
+ # n_neighbors=15, n_components=5, min_dist=0.0, metric="cosine", random_state=42
83
+ # )
84
+ # hdbscan_model = HDBSCAN(
85
+ # min_cluster_size=150,
86
+ # metric="euclidean",
87
+ # cluster_selection_method="eom",
88
+ # prediction_data=True,
89
+ # )
90
+ # reduce_umap_model = UMAP(
91
+ # n_neighbors=15, n_components=2, min_dist=0.0, metric="cosine", random_state=42
92
+ # )
93
 
94
 
95
  def get_parquet_urls(dataset, config, split):
 
113
 
114
  # @spaces.GPU
115
  def calculate_embeddings(docs):
116
+ return sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
 
 
117
 
118
 
119
  # @spaces.GPU
120
+ def fit_model(base_model, docs, embeddings):
121
  new_model = BERTopic(
122
  "english",
123
+ # Sub-models
124
  embedding_model=sentence_model,
125
+ # umap_model=umap_model,
126
+ # hdbscan_model=hdbscan_model,
127
  representation_model=representation_model,
128
+ # Hyperparameters
129
+ top_n_words=10,
130
+ verbose=True,
131
+ min_topic_size=15,
132
  )
133
  logging.info("Fitting new model")
134
  new_model.fit(docs, embeddings)
135
  logging.info("End fitting new model")
136
+
137
  if base_model is None:
138
  return new_model, new_model
139
 
 
142
  new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
143
  logging.info("The following topics are newly found:")
144
  logging.info(f"{new_topics}\n")
145
+ # updated_model.set_topic_labels(updated_model.topic_labels_)
146
+
147
  return updated_model, new_model
148
 
149
 
 
156
  limit = 1_000
157
  chunk_size = 300
158
  offset = 0
 
159
  base_model = None
160
  all_docs = []
161
  all_reduced_embeddings = np.empty((0, 2))
 
168
  offset = offset + chunk_size
169
  if not docs or offset >= limit:
170
  break
171
+ base_model, _ = fit_model(base_model, docs, embeddings)
172
+ llama2_labels = [
173
+ label[0][0].split("\n")[0]
174
+ for label in base_model.get_topics(full=True)["Llama2"].values()
175
+ ]
176
+ logging.info(f"Topics: {llama2_labels}")
177
+ base_model.set_topic_labels(llama2_labels)
178
+
179
  reduced_embeddings = UMAP(
180
  n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
181
  ).fit_transform(embeddings)
 
182
 
183
  all_docs.extend(docs)
184
  all_reduced_embeddings = np.vstack((all_reduced_embeddings, reduced_embeddings))
 
185
  topics_info = base_model.get_topic_info()
186
  topic_plot = base_model.visualize_documents(
187
+ all_docs, reduced_embeddings=all_reduced_embeddings, custom_labels=True
188
  )
189
+ logging.info(f"Topics for merged model: {base_model.topic_labels_}")
190
  yield topics_info, topic_plot
191
 
192
  logging.info("Finished processing all data")
 
194
 
195
 
196
  with gr.Blocks() as demo:
197
+ gr.Markdown("# 💠 Dataset Topic Discovery 🔭")
198
+ gr.Markdown("## Select dataset and text column")
199
+ with gr.Accordion("Data details", open=True):
200
+ with gr.Row():
201
+ with gr.Column(scale=3):
202
+ dataset_name = HuggingfaceHubSearch(
203
+ label="Hub Dataset ID",
204
+ placeholder="Search for dataset id on Huggingface",
205
+ search_type="dataset",
206
+ )
207
+ subset_dropdown = gr.Dropdown(label="Subset", visible=False)
208
+ split_dropdown = gr.Dropdown(label="Split", visible=False)
209
+
210
+ with gr.Accordion("Dataset preview", open=False):
211
+
212
+ @gr.render(inputs=[dataset_name, subset_dropdown, split_dropdown])
213
+ def embed(name, subset, split):
214
+ html_code = f"""
215
+ <iframe
216
+ src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
217
+ frameborder="0"
218
+ width="100%"
219
+ height="600px"
220
+ ></iframe>
221
+ """
222
+ return gr.HTML(value=html_code)
223
+
224
+ with gr.Row():
225
+ text_column_dropdown = gr.Dropdown(label="Text column name")
226
+ nested_text_column_dropdown = gr.Dropdown(
227
+ label="Nested text column name", visible=False
228
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ generate_button = gr.Button("Generate Notebook", variant="primary")
231
 
232
+ gr.Markdown("## Datamap")
 
233
  topics_plot = gr.Plot()
234
+ with gr.Accordion("Topics Info", open=False):
235
+ topics_df = gr.DataFrame(interactive=False, visible=True)
236
  generate_button.click(
237
  generate_topics,
238
  inputs=[
prompts.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ system_prompt = """
2
+ <s>[INST] <<SYS>>
3
+ You are a helpful, respectful and honest assistant for labeling topics.
4
+ <</SYS>>
5
+ """
6
+
7
+ example_prompt = """
8
+ I have a topic that contains the following documents:
9
+ - Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
10
+ - Meat, but especially beef, is the word food in terms of emissions.
11
+ - Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
12
+
13
+ The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.
14
+
15
+ Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
16
+
17
+ [/INST] Environmental impacts of eating meat
18
+ """
19
+
20
+ main_prompt = """
21
+ [INST]
22
+ I have a topic that contains the following documents:
23
+ [DOCUMENTS]
24
+
25
+ The topic is described by the following keywords: '[KEYWORDS]'.
26
+
27
+ Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
28
+ [/INST]
29
+ """