nbroad committed
Commit aca33e8 · 1 Parent(s): ad93c5f

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +0 -13
  2. app.py +197 -0
  3. requirements.txt +8 -0
  4. utils.py +452 -0
README.md CHANGED
@@ -1,13 +0,0 @@
- ---
- title: Bulk Embeddings
- emoji: 🐠
- colorFrom: purple
- colorTo: indigo
- sdk: gradio
- sdk_version: 3.36.1
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,197 @@
+ import gradio as gr
+
+ from utils import load_hf_dataset, get_model_and_tokenizer, batch_embed
+
+
+ # TODO: add instructor models
+ # "hkunlp/instructor-xl",
+ # "hkunlp/instructor-large",
+ # "hkunlp/instructor-base",
+
+ # model ids and hidden sizes
+ models_and_hidden_sizes = [
+     ("intfloat/e5-small-v2", 384),
+     ("intfloat/e5-base-v2", 768),
+     ("intfloat/e5-large-v2", 1024),
+     ("intfloat/multilingual-e5-small", 384),
+     ("intfloat/multilingual-e5-base", 768),
+     ("intfloat/multilingual-e5-large", 1024),
+     ("sentence-transformers/all-MiniLM-L6-v2", 384),
+     ("sentence-transformers/all-MiniLM-L12-v2", 384),
+     ("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 384),
+ ]
+
+ model_options = [
+     f"{model_name} (hidden_size = {hidden_size})"
+     for model_name, hidden_size in models_and_hidden_sizes
+ ]
+
+
+ opt2desc = {
+     "O2": "Most precise, slowest (O2: basic and extended general optimizations, transformers-specific fusions)",
+     "O3": "Less precise, faster (O3: O2 + gelu approx)",
+     "O4": "Least precise, fastest (O4: O3 + fp16/bf16)",
+ }
+
+ desc2opt = {v: k for k, v in opt2desc.items()}
+
+
+ optimization_options = list(opt2desc.values())
+
+
+ def run(
+     ds_name,
+     ds_config,
+     column_name,
+     ds_split,
+     model_choice,
+     opt_desc,
+     new_dataset_id,
+     num2skip,
+     num2embed,
+     progress=gr.Progress(),
+ ):
+     if progress is not None:
+         progress(0.5, "Loading dataset...")
+     ds = load_hf_dataset(ds_name, ds_config, ds_split)
+
+     opt_level = desc2opt[opt_desc]
+
+     model_name = model_choice.split()[0]
+
+     if progress is not None:
+         progress(0.2, "Downloading model and tokenizer...")
+     model, tokenizer = get_model_and_tokenizer(model_name, opt_level, progress)
+
+     doc_count, seconds_taken = batch_embed(
+         ds,
+         model,
+         tokenizer,
+         model_name=model_name,
+         column_name=column_name,
+         new_dataset_id=new_dataset_id,
+         opt_level=opt_level,
+         num2skip=num2skip,
+         num2embed=num2embed,
+         progress=progress,
+     )
+
+     return f"Embedded {doc_count} docs in {seconds_taken/60:.2f} minutes ({doc_count/seconds_taken:.1f} docs/sec)"
+
+
+ with gr.Blocks(title="Bulk embeddings") as demo:
+     gr.Markdown(
+         """
+ This Space allows you to embed a large dataset easily. For instance, this can easily create vectors for Wikipedia \
+ articles -- taking about __ hours and costing approximately $__.
+
+
+ This utilizes state-of-the-art open-source embedding models, \
+ and optimizes them for inference using Hugging Face [optimum](https://github.com/huggingface/optimum). There are various \
+ levels of optimization that can be applied - the quality of the embeddings will degrade as the optimizations increase.
+
+ Currently available options: O2/O3/O4 on T4/A10 GPUs using ONNX Runtime.
+
+ Future options:
+ - OpenVINO for CPU inference
+ - TensorRT for GPU inference
+ - Quantized models
+ - Instructor models
+ - Text splitting options
+ - More control over which rows to embed (skip some, stop early)
+ - Dynamic padding
+
+ ## Steps
+
+ 1. Upload the dataset to the Hugging Face Hub.
+ 2. Enter dataset details into the form below.
+ 3. Choose a model. These are taken from the top of the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
+ 4. Enter optimization level. See [here](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration) for details.
+ 5. Choose a name for the new dataset.
+ 6. Hit run!
+
+
+ ### Note:
+
+ If you have short documents, O3 will be faster than O4. If you have long documents, O4 will be faster than O3. \
+ O4 requires the tokenized documents to be padded to max length.
+
+     """
+     )
+
+     with gr.Row():
+         ds_name = gr.Textbox(
+             lines=1,
+             label="Dataset to load from Hugging Face Hub",
+             value="nbroad/basic_text_dataset",
+         )
+         ds_config = gr.Textbox(
+             lines=1, label="Dataset config (leave blank to use default)", value=""
+         )
+
+         column_name = gr.Textbox(lines=1, label="Enter column to embed", value="text")
+         ds_split = gr.Dropdown(
+             choices=["train", "validation", "test"],
+             label="Dataset split",
+             value="train",
+         )
+         # TODO: idx column
+         # TODO: text splitting options
+
+     with gr.Row():
+         model_choice = gr.Dropdown(
+             choices=model_options, label="Embedding model", value=model_options[0]
+         )
+         opt_desc = gr.Dropdown(
+             choices=optimization_options,
+             label="Optimization level",
+             value=optimization_options[0],
+         )
+
+     with gr.Row():
+         new_dataset_id = gr.Textbox(
+             lines=1,
+             label="New dataset name, including username",
+             value="nbroad/test-embeds",
+         )
+
+         num2skip = gr.Slider(
+             value=0,
+             minimum=0,
+             maximum=10_000_000,
+             step=1,
+             label="Number of rows to skip",
+         )
+
+         num2embed = gr.Slider(
+             value=-1,
+             minimum=-1,
+             maximum=10_000_000,
+             step=1,
+             label="Number of rows to embed (-1 = all)",
+         )
+
+     with gr.Row():
+         btn = gr.Button(value="Embed texts!")
+
+     last = gr.Textbox(value="")
+
+     btn.click(
+         fn=run,
+         inputs=[
+             ds_name,
+             ds_config,
+             column_name,
+             ds_split,
+             model_choice,
+             opt_desc,
+             new_dataset_id,
+             num2skip,
+             num2embed,
+         ],
+         outputs=last,
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue(concurrency_count=20).launch(show_error=True, debug=True, share=True)
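The O3-vs-O4 note in the app text comes down to padding: `utils.tokenize` (added below) pads to the longest sequence in the batch for O2/O3 (`padding=True`) but to the full 512 tokens for O4 (`padding="max_length"`). A minimal sketch of that difference, using the app's default model and two made-up documents:

```python
from transformers import AutoTokenizer

# Hypothetical example: two short documents tokenized the way utils.tokenize does.
texts = ["a short document", "another slightly longer example document"]
tok = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")

# O2/O3 path: pad only to the longest sequence in the batch.
dynamic = tok(texts, truncation=True, padding=True, max_length=512)

# O4 path: pad every sequence out to max_length, as the app does for O4.
fixed = tok(texts, truncation=True, padding="max_length", max_length=512)

print(len(dynamic["input_ids"][0]))  # a handful of tokens
print(len(fixed["input_ids"][0]))    # always 512
```

For short documents, the extra padding in the O4 path is wasted compute, which is why the note suggests O3 in that case.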
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ datasets==2.13.1
+ tokenizers>=0.11.1,!=0.11.3,<0.14
+ optimum[onnxruntime-gpu]==1.8.8
+ transformers==4.30.1
+ accelerate==0.20.3
+ gradio==3.35.2
+ --extra-index-url https://download.pytorch.org/whl/cu118
+ torch==2.0.1
utils.py ADDED
@@ -0,0 +1,452 @@
+ import os
+ import time
+ import shutil
+ from pathlib import Path
+ from typing import Union, Dict, List
+
+ import torch
+ import datasets
+ from datasets import load_dataset, Dataset
+ from transformers import AutoTokenizer, PreTrainedTokenizer
+ from huggingface_hub import Repository, create_repo, HfApi
+ from optimum.onnxruntime import (
+     AutoOptimizationConfig,
+     ORTModelForFeatureExtraction,
+     ORTOptimizer,
+ )
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ opt_configs = {
+     "O2": AutoOptimizationConfig.O2(),
+     "O3": AutoOptimizationConfig.O3(),
+     "O4": AutoOptimizationConfig.O4(),
+ }
+
+
+ def get_batch_size(device_name: str, model_name: str, opt_level: str):
+     """
+     TODO: run actual tests
+
+     T4 has 16GB
+     A10 has 24GB
+
+     Args:
+         device_name (`str`):
+             The name of the GPU device in use.
+         model_name (`str`):
+             The name of the model in use.
+         opt_level (`str`):
+             The optimization level in use.
+
+     Returns:
+         `int`:
+             The batch size to use.
+     """
+
+     if "small" in model_name:
+         bs = 192
+     elif "base" in model_name:
+         bs = 128
+     elif "large" in model_name:
+         bs = 64
+     else:
+         bs = 32
+
+     if "A10" in device_name:
+         bs *= 2
+
+     if opt_level == "O4":
+         bs *= 2
+
+     return bs
+
+
+ def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
+     """
+     Mean pool the token embeddings.
+
+     Args:
+         last_hidden_state (`tuple`):
+             The output of the model.
+         attention_mask (`torch.Tensor`):
+             The attention mask.
+
+     Returns:
+         `torch.Tensor`:
+             The mean pooled embeddings.
+     """
+     input_mask_expanded = (
+         attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
+     )
+     return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(
+         input_mask_expanded.sum(1), min=1e-9
+     )
+
+
+ def load_hf_dataset(ds_name: str, ds_config: str = None, ds_split: str = "train"):
+     """
+     Load a dataset from the HuggingFace Hub. Will be streaming so
+     as to not load the whole dataset to local storage.
+
+     Args:
+         ds_name (`str`):
+             The name of the dataset to load.
+         ds_config (`str`, *optional*, Defaults to `None`):
+             The configuration of the dataset to load.
+         ds_split (`str`, *optional*, Defaults to `"train"`):
+             The split of the dataset to load.
+
+     Returns:
+         ds (`datasets.IterableDataset`):
+             The loaded dataset.
+     """
+
+     if ds_config == "":
+         ds_config = None
+
+     ds = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
+
+     return ds
+
+
+ def get_model_and_tokenizer(model_name: str, optimization_level: str, progress):
+     """
+     Load the model and tokenizer from the HuggingFace Hub.
+
+     If the model is not already optimized, optimize it and save it to the local directory.
+
+     Args:
+         model_name (`str`):
+             The name of the model to load.
+         optimization_level (`str`):
+             The optimization level to use. Should be one of `"O2"`, `"O3"`, or `"O4"`.
+
+     Returns:
+         model (`ORTModelForFeatureExtraction`):
+             The optimized model.
+         tokenizer (`PreTrainedTokenizer`):
+             The tokenizer.
+     """
+     optimized_model_name = f"model_optimized_{optimization_level}.onnx"
+
+     model_dir = Path(model_name.replace("/", "_"))
+     if not (model_dir / optimized_model_name).exists():
+         if progress is not None:
+             progress(0.2, "Downloading tokenizer...")
+
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         tokenizer.save_pretrained(model_dir)
+
+         if progress is not None:
+             progress(0.4, "Downloading model...")
+
+         model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
+         model.save_pretrained(model_dir)
+
+         optimizer = ORTOptimizer.from_pretrained(model)
+         optimization_config = opt_configs[optimization_level]
+
+         if progress is not None:
+             progress(0.6, "Optimizing model...")
+
+         optimizer.optimize(save_dir=model_dir, optimization_config=optimization_config)
+         Path(model_dir / "model_optimized.onnx").rename(
+             model_dir / optimized_model_name
+         )
+
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+     if progress is not None:
+         progress(0.8, "Loading optimized model and tokenizer...")
+
+     return (
+         ORTModelForFeatureExtraction.from_pretrained(
+             model_dir,
+             file_name=optimized_model_name,
+             provider="CUDAExecutionProvider",
+         ),
+         tokenizer,
+     )
+
+
+ def tokenize(
+     examples: Dict[str, List[str]],
+     tokenizer: PreTrainedTokenizer,
+     column_name: str = "text",
+     padding: Union[bool, str] = True,
+     max_length: int = 512,
+ ):
+     """
+     Tokenize the examples using the tokenizer.
+
+     Args:
+         examples (`Dict[str, List[str]]`):
+             examples to tokenize
+         tokenizer (`PreTrainedTokenizer`):
+             tokenizer to use
+         column_name (`str`, *optional*, defaults to `text`):
+             column name to use for tokenization. Defaults to `text`
+         padding (`bool`, *optional*, defaults to `True`):
+             whether to pad the examples. Defaults to `True`
+             Use `"max_length"` if using `O4` optimization level
+             If `True`, the batch will be padded to the longest in the batch.
+         max_length (`int`, *optional*, Defaults to `512`):
+             max length to use for the model. Defaults to `512`.
+             Any sequences longer will be truncated.
+             If padding is `"max_length"`, the padding will be added until the sequence
+             is of length `max_length`.
+
+     Returns:
+         `Dict[str, List[List[int]]]`:
+             tokenized examples
+     """
+     # TODO: add lengths, sort by length, use dynamic padding
+     # TODO: option for controlling length for models that can go shorter/longer than 512
+     return tokenizer(
+         examples[column_name], truncation=True, padding=padding, max_length=max_length
+     )
+
+
+ @torch.inference_mode()
+ def batch_embed(
+     ds: datasets.IterableDataset,
+     model: ORTModelForFeatureExtraction,
+     tokenizer: PreTrainedTokenizer,
+     model_name: str,
+     column_name: str,
+     new_dataset_id: str,
+     opt_level: str,
+     upload_batch_size: int = 10_000,
+     map_batch_size: int = 2000,
+     num2skip: int = 0,
+     num2embed: int = -1,
+     progress=None,
+ ):
+     """
+     Run the model on the dataset and upload the embeddings to the hub.
+
+     Args:
+         ds (`datasets.Dataset`):
+             dataset to embed. From `load_hf_dataset`
+         model (`ORTModelForFeatureExtraction`):
+             model to use for embedding. From `get_model_and_tokenizer`
+         tokenizer (`AutoTokenizer`):
+             tokenizer to use for embedding. From `get_model_and_tokenizer`
+         model_name (`str`):
+             name of the model to use. Used to determine batch size.
+         column_name (`str`):
+             column name to use for embedding. Default option in gradio app is `text`
+         new_dataset_id (`str`):
+             id of the new dataset to create. Should include username or organization.
+             e.g. nbroad/new-embeddings
+         opt_level (`str`):
+             optimization level to use. Should be one of `O2`, `O3`, `O4`
+             See here for more details on optimization levels:
+             https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration
+         upload_batch_size (`int`, *optional*, defaults to `10_000`):
+             number of embeddings to upload at once. Defaults to 10,000.
+         map_batch_size (`int`, *optional*, defaults to `2000`):
+             number of examples to tokenize at once. Defaults to 2000.
+         num2skip (`int`, *optional*, defaults to `0`):
+             number of examples to skip. Defaults to 0.
+         num2embed (`int`, *optional*, defaults to `-1`):
+             number of examples to embed. Defaults to -1, which means all examples.
+
+     Returns:
+         current_count (`int`):
+             number of examples embedded so far
+         time_taken (`float`):
+             time taken to embed the examples in seconds
+
+     """
+
+     api = HfApi(
+         token=os.environ["HF_TOKEN"],
+     )
+
+     username = api.whoami()["name"]
+
+     if "/" not in new_dataset_id:
+         new_dataset_id = username + "/" + new_dataset_id
+
+     repo = init_git_repo(new_dataset_id)
+
+     iterator = iter(
+         ds.map(
+             tokenize,
+             batched=True,
+             batch_size=map_batch_size,
+             fn_kwargs={
+                 "tokenizer": tokenizer,
+                 "column_name": column_name,
+                 "padding": "max_length" if opt_level == "O4" else True,
+             },
+             remove_columns=ds.column_names,
+         )
+     )
+
+     embeds = []
+     texts = []
+
+     # last_count keeps track of how many had been embedded since last push
+     last_count = 0
+     # current count keeps track of how many have been embedded in total
+     current_count = 0
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     inference_bs = get_batch_size(torch.cuda.get_device_name(0), model_name, opt_level)
+
+     loop = True
+
+     # skip through some examples
+     if num2skip > 0:
+         [next(iterator) for _ in range(num2skip)]
+
+     start_time = time.time()
+     while loop:
+         batch = [next(iterator, None) for _ in range(inference_bs)]
+
+         # batch will have None values when iterator runs out
+         if batch[-1] is None:
+             batch = [x for x in batch if x is not None]
+             loop = False
+             if len(batch) == 0:
+                 break
+
+         ids = torch.tensor([b["input_ids"] for b in batch], device=device)
+         mask = torch.tensor([b["attention_mask"] for b in batch], device=device)
+         t_ids = torch.zeros_like(ids)
+
+         outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
+
+         embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
+         texts.extend([b[column_name] for b in batch])
+
+         current_count += len(batch)
+
+         # Check if we have embedded enough examples (num2embed == -1 means embed everything)
+         if num2embed > 0 and current_count >= num2embed:
+             diff = current_count - num2embed
+             if diff > 0:
+                 embeds = embeds[:-diff]
+                 texts = texts[:-diff]
+             current_count = num2embed
+             break
+
+         # Periodically upload to the hub
+         if len(embeds) > upload_batch_size:
+             push_to_repo(repo, last_count, current_count, embeds, texts)
+             embeds = []
+             texts = []
+             last_count = current_count
+
+         # Provide updates
+         if progress is not None:
+             progress(
+                 (current_count, None),
+                 "Embedding docs...",
+                 total=None,
+                 unit="Docs Embedded",
+             )
+
+     time_taken = time.time() - start_time
+
+     # If there are any remaining embeddings, upload them
+     if len(embeds) > 0:
+         push_to_repo(repo, last_count, current_count, embeds, texts)
+
+     return current_count, time_taken
+
+
+ def init_git_repo(repo_id: str):
+     """
+     Initialize a git repo for the new dataset.
+
+     ***Removes existing local folder if exists***
+
+     Args:
+         repo_id (`str`):
+             id of the new dataset to create. Should include username or organization.
+             e.g. nbroad/new-embeddings
+     """
+     local_dir = repo_id.replace("/", "_")
+
+     create_repo(
+         repo_id,
+         repo_type="dataset",
+         token=os.environ["HF_TOKEN"],
+         private=True,
+         exist_ok=True,
+     )
+     try:
+         repo = Repository(
+             local_dir=local_dir,
+             clone_from=repo_id,
+             repo_type="dataset",
+             token=os.environ["HF_TOKEN"],
+             skip_lfs_files=True,
+         )
+     except EnvironmentError:
+         shutil.rmtree(local_dir)
+         repo = Repository(
+             local_dir=local_dir,
+             clone_from=repo_id,
+             repo_type="dataset",
+             token=os.environ["HF_TOKEN"],
+             skip_lfs_files=True,
+         )
+
+     if repo is not None:
+         repo.git_pull()
+
+     return repo
+
+
+ def push_to_repo(
408
+ repo: str,
409
+ last_count: int,
410
+ current_count: int,
411
+ embeds: List[List[float]],
412
+ texts: List[str],
413
+ ):
414
+ """
415
+ Push embeddings to the repo.
416
+
417
+ Args:
418
+ repo (`huggingface_hub.Repository`):
419
+ repo to push to
420
+ last_count (`int`):
421
+ last count of embeddings.
422
+ This is the number of embeddings that have already been pushed.
423
+ current_count (`int`):
424
+ current count of embeddings.
425
+ This is the number of embeddings that have been pushed after this batch.
426
+ embeds (`List[List[float]]`):
427
+ list of embeddings to push to the repo
428
+ texts (`List[str]`):
429
+ list of texts to push to the repo
430
+ """
431
+
432
+ # TODO: write dataset loading script as well
433
+
434
+ temp_ds = Dataset.from_dict(
435
+ {
436
+ "embedding": embeds,
437
+ "text": texts,
438
+ }
439
+ )
440
+
441
+ data_dir = Path(repo.local_dir) / "data"
442
+ data_dir.mkdir(exist_ok=True, parents=True)
443
+
444
+ temp_ds.to_parquet(
445
+ str(data_dir / f"embeddings_{last_count}_{current_count}.parquet")
446
+ )
447
+
448
+ repo.push_to_hub(
449
+ commit_message=f"Embedded examples {last_count} thru {current_count}",
450
+ blocking=False,
451
+ auto_lfs_prune=True,
452
+ )
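Since `push_to_repo` writes plain parquet shards under `data/` in the new dataset repo (there is no loading script yet, per the TODO), the embeddings can later be read back with `datasets` alone. A rough sketch, assuming the app's default `nbroad/test-embeds` repo id and an authenticated session (the repo is created private):

```python
from datasets import load_dataset

# Hypothetical follow-up: read back the parquet shards that push_to_repo uploaded.
# The repo is private, so log in first (huggingface-cli login) or set a token.
embeds_ds = load_dataset("nbroad/test-embeds", split="train", use_auth_token=True)

print(embeds_ds)                        # two columns: "embedding" and "text"
first_vector = embeds_ds[0]["embedding"]
print(len(first_vector))                # hidden_size of the chosen model, e.g. 384
```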