davidberenstein1957 HF staff committed on
Commit
f5ab4cb
1 Parent(s): b7c81a3

update deployment with API providers

README.md CHANGED
@@ -76,21 +76,24 @@ launch()
76
 
77
  - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
78
 
79
- Optionally, you can set the following environment variables to customize the generation process.
80
 
81
  - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
82
  - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
83
  - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
84
 
85
- Optionally, you can use different models and APIs. For providers outside of Hugging Face, we provide an integration through [LiteLLM](https://docs.litellm.ai/docs/providers).
86
 
87
- - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`, `http://127.0.0.1:11434/v1/`.
88
- - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
89
  - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
90
 
91
 SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints, and you can set the following environment variables to use it with models other than Llama3 and Qwen2.
92
 
93
- - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. Llama3 and Qwen2 are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"` respectively. For other models, you can pass a custom pre-query template string.
 
94
 
95
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
96
76
 
77
  - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
78
 
79
+ You can set the following environment variables to customize the generation process.
80
 
81
  - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
82
  - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
83
  - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
84
 
85
+ Optionally, you can use different API providers and models.
86
 
87
+ - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 
88
  - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
89
+ - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
90
+ - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/v1/`.
91
+ - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. a TGI server or a dedicated Inference Endpoint. If you want to use serverless inference, set only `MODEL`.
92
 
93
 SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints, and you can set the following environment variables to use it with models other than Llama3 and Qwen2.
94
 
95
+ - `TOKENIZER_ID`: The tokenizer ID to use for the Magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
96
+ - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
97
 
98
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
99
 
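Taken together, a minimal sketch of how the documented variables combine for an OpenAI-compatible provider, mirroring the scripts added under `examples/` in this commit (the model id and API key here are placeholders, not prescriptions):

```python
# pip install synthetic-dataset-generator
import os

from synthetic_dataset_generator import launch

assert os.getenv("HF_TOKEN")  # needed to push the generated dataset to the Hub
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"  # OpenAI-compatible API
os.environ["MODEL"] = "gpt-4o"  # placeholder model id for that provider
os.environ["API_KEY"] = os.environ["OPENAI_API_KEY"]  # key for the provider; defaults to HF_TOKEN if unset
os.environ["MAX_NUM_ROWS"] = "1000"  # optional generation limits
os.environ["MAX_NUM_TOKENS"] = "2048"

launch()
```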
examples/enforce_mapgie_template copy.py DELETED
@@ -1,9 +0,0 @@
1
- # pip install synthetic-dataset-generator
2
- import os
3
-
4
- from synthetic_dataset_generator import launch
5
-
6
- os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "my_custom_template"
7
- os.environ["MODEL"] = "google/gemma-2-9b-it"
8
-
9
- launch()
examples/{ollama_local.py → hf-serverless_deployment.py} RENAMED
@@ -4,7 +4,7 @@ import os
4
  from synthetic_dataset_generator import launch
5
 
6
  assert os.getenv("HF_TOKEN") # push the data to huggingface
7
- os.environ["BASE_URL"] = "http://127.0.0.1:11434/v1/"
8
- os.environ["MODEL"] = "llama3.1"
9
 
10
  launch()
 
4
  from synthetic_dataset_generator import launch
5
 
6
  assert os.getenv("HF_TOKEN") # push the data to huggingface
7
+ os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model
8
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
9
 
10
  launch()
examples/ollama_deployment.py ADDED
@@ -0,0 +1,17 @@
1
+ # pip install synthetic-dataset-generator
2
+ # ollama serve
3
+ # ollama run llama3.1:8b-instruct-q8_0
4
+ import os
5
+
6
+ from synthetic_dataset_generator import launch
7
+
8
+ assert os.getenv("HF_TOKEN") # push the data to huggingface
9
+ os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
10
+ os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0"
11
+ os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"
12
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
13
+ os.environ["MAX_NUM_ROWS"] = "10000"
14
+ os.environ["DEFAULT_BATCH_SIZE"] = "5"
15
+ os.environ["MAX_NUM_TOKENS"] = "2048"
16
+
17
+ launch()
examples/{openai_local.py → openai_deployment.py} RENAMED
@@ -4,8 +4,9 @@ import os
4
  from synthetic_dataset_generator import launch
5
 
6
  assert os.getenv("HF_TOKEN") # push the data to huggingface
7
- os.environ["BASE_URL"] = "https://api.openai.com/v1/"
8
  os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")
9
  os.environ["MODEL"] = "gpt-4o"
 
10
 
11
  launch()
 
4
  from synthetic_dataset_generator import launch
5
 
6
  assert os.getenv("HF_TOKEN") # push the data to huggingface
7
+ os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"
8
  os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")
9
  os.environ["MODEL"] = "gpt-4o"
10
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None # chat data not supported with OpenAI
11
 
12
  launch()
examples/tgi_or_hf_dedicated.py ADDED
@@ -0,0 +1,14 @@
1
+ # pip install synthetic-dataset-generator
2
+ import os
3
+
4
+ from synthetic_dataset_generator import launch
5
+
6
+ assert os.getenv("HF_TOKEN") # push the data to huggingface
7
+ os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/"
8
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
9
+ os.environ["TOKENIZER_ID"] = (
10
+ "meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint
11
+ )
12
+ os.environ["MODEL"] = None # model is linked to endpoint
13
+
14
+ launch()
pdm.lock CHANGED
@@ -5,7 +5,7 @@
5
  groups = ["default"]
6
  strategy = ["inherit_metadata"]
7
  lock_version = "4.5.0"
8
- content_hash = "sha256:a6d7f86e9a168e7eb78801faafa8a9ea13270310ee9c21a0e23aeab277d973a0"
9
 
10
  [[metadata.targets]]
11
  requires_python = ">=3.10,<3.13"
@@ -491,8 +491,11 @@ files = [
491
 
492
  [[package]]
493
  name = "distilabel"
494
- version = "1.4.1"
495
  requires_python = ">=3.9"
496
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
497
  groups = ["default"]
498
  dependencies = [
@@ -512,30 +515,30 @@ dependencies = [
512
  "typer>=0.9.0",
513
  "universal-pathlib>=0.2.2",
514
  ]
515
- files = [
516
- {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
517
- {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
518
- ]
519
 
520
  [[package]]
521
  name = "distilabel"
522
- version = "1.4.1"
523
- extras = ["argilla", "hf-inference-endpoints", "instructor", "outlines"]
524
  requires_python = ">=3.9"
525
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
526
  groups = ["default"]
527
  dependencies = [
528
  "argilla>=2.0.0",
529
- "distilabel==1.4.1",
530
  "huggingface-hub>=0.22.0",
531
  "instructor>=1.2.3",
532
  "ipython",
 
533
  "numba>=0.54.0",
534
  "outlines>=0.0.40",
535
- ]
536
- files = [
537
- {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
538
- {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
539
  ]
540
 
541
  [[package]]
@@ -824,7 +827,7 @@ files = [
824
 
825
  [[package]]
826
  name = "httpx"
827
- version = "0.28.1"
828
  requires_python = ">=3.8"
829
  summary = "The next generation HTTP client."
830
  groups = ["default"]
@@ -833,10 +836,11 @@ dependencies = [
833
  "certifi",
834
  "httpcore==1.*",
835
  "idna",
 
836
  ]
837
  files = [
838
- {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
839
- {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
840
  ]
841
 
842
  [[package]]
@@ -1068,6 +1072,22 @@ files = [
1068
  {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
1069
  ]
1070
 
1071
  [[package]]
1072
  name = "llvmlite"
1073
  version = "0.43.0"
@@ -1538,6 +1558,21 @@ files = [
1538
  {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
1539
  ]
1540
 
1541
  [[package]]
1542
  name = "openai"
1543
  version = "1.57.4"
 
5
  groups = ["default"]
6
  strategy = ["inherit_metadata"]
7
  lock_version = "4.5.0"
8
+ content_hash = "sha256:e95140895657d62ad438ff1815ddf1798abbb342ddd2649ae462620b8b3f5350"
9
 
10
  [[metadata.targets]]
11
  requires_python = ">=3.10,<3.13"
 
491
 
492
  [[package]]
493
  name = "distilabel"
494
+ version = "1.5.0"
495
  requires_python = ">=3.9"
496
+ git = "https://github.com/argilla-io/distilabel.git"
497
+ ref = "feat/add-magpie-support-llama-cpp-ollama"
498
+ revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
499
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
500
  groups = ["default"]
501
  dependencies = [
 
515
  "typer>=0.9.0",
516
  "universal-pathlib>=0.2.2",
517
  ]
518
 
519
  [[package]]
520
  name = "distilabel"
521
+ version = "1.5.0"
522
+ extras = ["argilla", "hf-inference-endpoints", "hf-transformers", "instructor", "llama-cpp", "ollama", "openai", "outlines"]
523
  requires_python = ">=3.9"
524
+ git = "https://github.com/argilla-io/distilabel.git"
525
+ ref = "feat/add-magpie-support-llama-cpp-ollama"
526
+ revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
527
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
528
  groups = ["default"]
529
  dependencies = [
530
  "argilla>=2.0.0",
531
+ "distilabel @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
532
  "huggingface-hub>=0.22.0",
533
  "instructor>=1.2.3",
534
  "ipython",
535
+ "llama-cpp-python>=0.2.0",
536
  "numba>=0.54.0",
537
+ "ollama>=0.1.7",
538
+ "openai>=1.0.0",
539
  "outlines>=0.0.40",
540
+ "torch>=2.0.0",
541
+ "transformers>=4.34.1",
542
  ]
543
 
544
  [[package]]
 
827
 
828
  [[package]]
829
  name = "httpx"
830
+ version = "0.27.2"
831
  requires_python = ">=3.8"
832
  summary = "The next generation HTTP client."
833
  groups = ["default"]
 
836
  "certifi",
837
  "httpcore==1.*",
838
  "idna",
839
+ "sniffio",
840
  ]
841
  files = [
842
+ {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
843
+ {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
844
  ]
845
 
846
  [[package]]
 
1072
  {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
1073
  ]
1074
 
1075
+ [[package]]
1076
+ name = "llama-cpp-python"
1077
+ version = "0.3.5"
1078
+ requires_python = ">=3.8"
1079
+ summary = "Python bindings for the llama.cpp library"
1080
+ groups = ["default"]
1081
+ dependencies = [
1082
+ "diskcache>=5.6.1",
1083
+ "jinja2>=2.11.3",
1084
+ "numpy>=1.20.0",
1085
+ "typing-extensions>=4.5.0",
1086
+ ]
1087
+ files = [
1088
+ {file = "llama_cpp_python-0.3.5.tar.gz", hash = "sha256:f5ce47499d53d3973e28ca5bdaf2dfe820163fa3fb67e3050f98e2e9b58d2cf6"},
1089
+ ]
1090
+
1091
  [[package]]
1092
  name = "llvmlite"
1093
  version = "0.43.0"
 
1558
  {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
1559
  ]
1560
 
1561
+ [[package]]
1562
+ name = "ollama"
1563
+ version = "0.4.4"
1564
+ requires_python = "<4.0,>=3.8"
1565
+ summary = "The official Python client for Ollama."
1566
+ groups = ["default"]
1567
+ dependencies = [
1568
+ "httpx<0.28.0,>=0.27.0",
1569
+ "pydantic<3.0.0,>=2.9.0",
1570
+ ]
1571
+ files = [
1572
+ {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"},
1573
+ {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"},
1574
+ ]
1575
+
1576
  [[package]]
1577
  name = "openai"
1578
  version = "1.57.4"
pyproject.toml CHANGED
@@ -18,7 +18,7 @@ readme = "README.md"
18
  license = {text = "Apache 2"}
19
 
20
  dependencies = [
21
- "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1,<2.0.0",
22
  "gradio[oauth]>=5.4.0,<6.0.0",
23
  "transformers>=4.44.2,<5.0.0",
24
  "sentence-transformers>=3.2.0,<4.0.0",
 
18
  license = {text = "Apache 2"}
19
 
20
  dependencies = [
21
+ "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines] @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
22
  "gradio[oauth]>=5.4.0,<6.0.0",
23
  "transformers>=4.44.2,<5.0.0",
24
  "sentence-transformers>=3.2.0,<4.0.0",
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -121,7 +121,7 @@ def generate_dataset(
121
  {
122
  "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
123
  }
124
- for i in range(int(num_rows / 50))
125
  ]
126
  batch = list(prompt_rewriter.process(inputs=inputs))
127
  prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
 
121
  {
122
  "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
123
  }
124
+ for i in range(int(num_rows / 100))
125
  ]
126
  batch = list(prompt_rewriter.process(inputs=inputs))
127
  prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -178,26 +178,41 @@ def generate_dataset(
178
 
179
  dataframe = pd.DataFrame(distiset_results)
180
  if multi_label:
181
- dataframe["labels"] = dataframe["labels"].apply(
182
- lambda x: list(
183
- set(
184
- [
185
- label.lower().strip()
186
- if (label is not None and label.lower().strip() in labels)
187
- else random.choice(labels)
188
- for label in x
189
- ]
190
- )
191
- )
192
- )
193
  dataframe = dataframe[dataframe["labels"].notna()]
194
  else:
195
  dataframe = dataframe.rename(columns={"labels": "label"})
196
- dataframe["label"] = dataframe["label"].apply(
197
- lambda x: x.lower().strip()
198
- if x and x.lower().strip() in labels
199
- else random.choice(labels)
200
- )
201
  dataframe = dataframe[dataframe["text"].notna()]
202
 
203
  progress(1.0, desc="Dataset created")
 
178
 
179
  dataframe = pd.DataFrame(distiset_results)
180
  if multi_label:
181
+
182
+ def _validate_labels(x):
183
+ if isinstance(x, str): # single label
184
+ return [x.lower().strip()]
185
+ elif isinstance(x, list): # multiple labels
186
+ return [
187
+ label.lower().strip()
188
+ for label in x
189
+ if label.lower().strip() in labels
190
+ ]
191
+ else:
192
+ return [random.choice(labels)]
193
+
194
+ dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
195
  dataframe = dataframe[dataframe["labels"].notna()]
196
  else:
197
+
198
+ def _validate_labels(x):
199
+ if isinstance(x, str) and x.lower().strip() in labels:
200
+ return x.lower().strip()
201
+ elif isinstance(x, list):
202
+ options = [
203
+ label.lower().strip()
204
+ for label in x
205
+ if isinstance(label, str) and label.lower().strip() in labels
206
+ ]
207
+ if options:
208
+ return random.choice(options)
209
+ else:
210
+ return random.choice(labels)
211
+ else:
212
+ return random.choice(labels)
213
+
214
  dataframe = dataframe.rename(columns={"labels": "label"})
215
+ dataframe["label"] = dataframe["label"].apply(_validate_labels)
216
  dataframe = dataframe[dataframe["text"].notna()]
217
 
218
  progress(1.0, desc="Dataset created")
src/synthetic_dataset_generator/constants.py CHANGED
@@ -7,39 +7,63 @@ import argilla as rg
7
  TEXTCAT_TASK = "text_classification"
8
  SFT_TASK = "supervised_fine_tuning"
9
 
10
- # Hugging Face
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if not HF_TOKEN:
13
  raise ValueError(
14
  "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
15
  )
16
 
17
- # Inference
18
- MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
19
- MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
20
- DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
21
- MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
22
- BASE_URL = os.getenv("BASE_URL", default=None)
23
-
24
  _API_KEY = os.getenv("API_KEY")
25
- if _API_KEY:
26
- API_KEYS = [_API_KEY]
27
- else:
28
- API_KEYS = [os.getenv("HF_TOKEN")] + [
29
- os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)
30
- ]
31
  API_KEYS = [token for token in API_KEYS if token]
32
 
33
  # Determine if SFT is available
34
  SFT_AVAILABLE = False
35
  llama_options = ["llama3", "llama-3", "llama 3"]
36
  qwen_options = ["qwen2", "qwen-2", "qwen 2"]
37
- if os.getenv("MAGPIE_PRE_QUERY_TEMPLATE"):
 
38
  SFT_AVAILABLE = True
39
- passed_pre_query_template = os.getenv("MAGPIE_PRE_QUERY_TEMPLATE")
40
- if passed_pre_query_template.lower() in llama_options:
41
  MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
42
- elif passed_pre_query_template.lower() in qwen_options:
43
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
44
  else:
45
  MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
@@ -54,12 +78,12 @@ elif MODEL.lower() in qwen_options or any(
54
  SFT_AVAILABLE = True
55
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
56
 
57
- if BASE_URL:
58
  SFT_AVAILABLE = False
59
 
60
  if not SFT_AVAILABLE:
61
  warnings.warn(
62
- message="`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints to generate chat data."
63
  )
64
  MAGPIE_PRE_QUERY_TEMPLATE = None
65
 
@@ -67,11 +91,12 @@ if not SFT_AVAILABLE:
67
  STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
68
 
69
  # Argilla
70
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
71
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
72
- if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
73
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
74
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
 
75
 
76
  if not ARGILLA_API_URL or not ARGILLA_API_KEY:
77
  warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
 
7
  TEXTCAT_TASK = "text_classification"
8
  SFT_TASK = "supervised_fine_tuning"
9
 
10
+ # Inference
11
+ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
12
+ MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
13
+ DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
14
+
15
+ # Models
16
+ MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
17
+ TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
18
+ OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
19
+ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
20
+ HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
21
+ if HUGGINGFACE_BASE_URL and MODEL:
22
+ raise ValueError(
23
+ "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
24
+ )
25
+ if OPENAI_BASE_URL or OLLAMA_BASE_URL:
26
+ if not MODEL:
27
+ raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
28
+
29
+
30
+
31
+ # Check if multiple base URLs are provided
32
+ base_urls = [
33
+ url for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL] if url
34
+ ]
35
+ if len(base_urls) > 1:
36
+ raise ValueError(
37
+ f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
38
+ )
39
+ BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL
40
+
41
+
42
+ # API Keys
43
  HF_TOKEN = os.getenv("HF_TOKEN")
44
  if not HF_TOKEN:
45
  raise ValueError(
46
  "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
47
  )
48
 
49
  _API_KEY = os.getenv("API_KEY")
50
+ API_KEYS = (
51
+ [_API_KEY]
52
+ if _API_KEY
53
+ else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
54
+ )
 
55
  API_KEYS = [token for token in API_KEYS if token]
56
 
57
  # Determine if SFT is available
58
  SFT_AVAILABLE = False
59
  llama_options = ["llama3", "llama-3", "llama 3"]
60
  qwen_options = ["qwen2", "qwen-2", "qwen 2"]
61
+
62
+ if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
63
  SFT_AVAILABLE = True
64
+ if passed_pre_query_template in llama_options:
 
65
  MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
66
+ elif passed_pre_query_template in qwen_options:
67
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
68
  else:
69
  MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
 
78
  SFT_AVAILABLE = True
79
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
80
 
81
+ if OPENAI_BASE_URL:
82
  SFT_AVAILABLE = False
83
 
84
  if not SFT_AVAILABLE:
85
  warnings.warn(
86
+ "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`."
87
  )
88
  MAGPIE_PRE_QUERY_TEMPLATE = None
89
 
 
91
  STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
92
 
93
  # Argilla
94
+ ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
95
+ "ARGILLA_API_URL_SDG_REVIEWER"
96
+ )
97
+ ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
98
+ "ARGILLA_API_KEY_SDG_REVIEWER"
99
+ )
100
 
101
  if not ARGILLA_API_URL or not ARGILLA_API_KEY:
102
  warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -1,4 +1,15 @@
1
- from synthetic_dataset_generator.constants import API_KEYS
2
 
3
  TOKEN_INDEX = 0
4
 
@@ -8,3 +19,71 @@ def _get_next_api_key():
8
  api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
9
  TOKEN_INDEX += 1
10
  return api_key
1
+ import gradio as gr
2
+ from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM
3
+
4
+ from synthetic_dataset_generator.constants import (
5
+ API_KEYS,
6
+ HUGGINGFACE_BASE_URL,
7
+ MAGPIE_PRE_QUERY_TEMPLATE,
8
+ MODEL,
9
+ OLLAMA_BASE_URL,
10
+ OPENAI_BASE_URL,
11
+ TOKENIZER_ID,
12
+ )
13
 
14
  TOKEN_INDEX = 0
15
 
 
19
  api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
20
  TOKEN_INDEX += 1
21
  return api_key
22
+
23
+
24
+ def _get_llm(use_magpie_template=False, **kwargs):
25
+ if OPENAI_BASE_URL:
26
+ llm = OpenAILLM(
27
+ model=MODEL,
28
+ base_url=OPENAI_BASE_URL,
29
+ api_key=_get_next_api_key(),
30
+ **kwargs,
31
+ )
32
+ if "generation_kwargs" in kwargs:
33
+ if "stop_sequences" in kwargs["generation_kwargs"]:
34
+ kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
35
+ "stop_sequences"
36
+ ]
37
+ del kwargs["generation_kwargs"]["stop_sequences"]
38
+ if "do_sample" in kwargs["generation_kwargs"]:
39
+ del kwargs["generation_kwargs"]["do_sample"]
40
+ elif OLLAMA_BASE_URL:
41
+ if "generation_kwargs" in kwargs:
42
+ if "max_new_tokens" in kwargs["generation_kwargs"]:
43
+ kwargs["generation_kwargs"]["num_predict"] = kwargs[
44
+ "generation_kwargs"
45
+ ]["max_new_tokens"]
46
+ del kwargs["generation_kwargs"]["max_new_tokens"]
47
+ if "stop_sequences" in kwargs["generation_kwargs"]:
48
+ kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
49
+ "stop_sequences"
50
+ ]
51
+ del kwargs["generation_kwargs"]["stop_sequences"]
52
+ if "do_sample" in kwargs["generation_kwargs"]:
53
+ del kwargs["generation_kwargs"]["do_sample"]
54
+ options = kwargs["generation_kwargs"]
55
+ del kwargs["generation_kwargs"]
56
+ kwargs["generation_kwargs"] = {}
57
+ kwargs["generation_kwargs"]["options"] = options
58
+ llm = OllamaLLM(
59
+ model=MODEL,
60
+ host=OLLAMA_BASE_URL,
61
+ tokenizer_id=TOKENIZER_ID or MODEL,
62
+ **kwargs,
63
+ )
64
+ elif HUGGINGFACE_BASE_URL:
65
+ kwargs["generation_kwargs"]["do_sample"] = True
66
+ llm = InferenceEndpointsLLM(
67
+ api_key=_get_next_api_key(),
68
+ base_url=HUGGINGFACE_BASE_URL,
69
+ tokenizer_id=TOKENIZER_ID or MODEL,
70
+ **kwargs,
71
+ )
72
+ else:
73
+ llm = InferenceEndpointsLLM(
74
+ api_key=_get_next_api_key(),
75
+ tokenizer_id=TOKENIZER_ID or MODEL,
76
+ model_id=MODEL,
77
+ magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
78
+ **kwargs,
79
+ )
80
+
81
+ return llm
82
+
83
+
84
+ try:
85
+ llm = _get_llm()
86
+ llm.load()
87
+ llm.generate([[{"content": "Hello, world!", "role": "user"}]])
88
+ except Exception as e:
89
+ gr.Error(f"Error loading {llm.__class__.__name__}: {e}")
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -1,4 +1,3 @@
1
- from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
4
  from synthetic_dataset_generator.constants import (
@@ -7,7 +6,7 @@ from synthetic_dataset_generator.constants import (
7
  MAX_NUM_TOKENS,
8
  MODEL,
9
  )
10
- from synthetic_dataset_generator.pipelines.base import _get_next_api_key
11
 
12
  INFORMATION_SEEKING_PROMPT = (
13
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -149,18 +148,13 @@ def _get_output_mappings(num_turns):
149
 
150
 
151
  def get_prompt_generator():
152
  prompt_generator = TextGeneration(
153
- llm=InferenceEndpointsLLM(
154
- api_key=_get_next_api_key(),
155
- model_id=MODEL,
156
- tokenizer_id=MODEL,
157
- base_url=BASE_URL,
158
- generation_kwargs={
159
- "temperature": 0.8,
160
- "max_new_tokens": MAX_NUM_TOKENS,
161
- "do_sample": True,
162
- },
163
- ),
164
  system_prompt=PROMPT_CREATION_PROMPT,
165
  use_system_prompt=True,
166
  )
@@ -172,38 +166,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
172
  input_mappings = _get_output_mappings(num_turns)
173
  output_mappings = input_mappings.copy()
174
  if num_turns == 1:
175
  magpie_generator = Magpie(
176
- llm=InferenceEndpointsLLM(
177
- model_id=MODEL,
178
- tokenizer_id=MODEL,
179
- base_url=BASE_URL,
180
- api_key=_get_next_api_key(),
181
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
182
- generation_kwargs={
183
- "temperature": temperature,
184
- "do_sample": True,
185
- "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
186
- "stop_sequences": _STOP_SEQUENCES,
187
- },
188
  ),
189
  n_turns=num_turns,
190
  output_mappings=output_mappings,
191
  only_instruction=True,
192
  )
193
  else:
194
  magpie_generator = Magpie(
195
- llm=InferenceEndpointsLLM(
196
- model_id=MODEL,
197
- tokenizer_id=MODEL,
198
- base_url=BASE_URL,
199
- api_key=_get_next_api_key(),
200
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
201
- generation_kwargs={
202
- "temperature": temperature,
203
- "do_sample": True,
204
- "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
205
- "stop_sequences": _STOP_SEQUENCES,
206
- },
207
  ),
208
  end_with_user=True,
209
  n_turns=num_turns,
@@ -214,50 +204,33 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
214
 
215
 
216
  def get_prompt_rewriter():
217
- prompt_rewriter = TextGeneration(
218
- llm=InferenceEndpointsLLM(
219
- model_id=MODEL,
220
- tokenizer_id=MODEL,
221
- base_url=BASE_URL,
222
- api_key=_get_next_api_key(),
223
- generation_kwargs={
224
- "temperature": 1,
225
- },
226
- ),
227
- )
228
  prompt_rewriter.load()
229
  return prompt_rewriter
230
 
231
 
232
  def get_response_generator(system_prompt, num_turns, temperature, is_sample):
233
  if num_turns == 1:
234
  response_generator = TextGeneration(
235
- llm=InferenceEndpointsLLM(
236
- model_id=MODEL,
237
- tokenizer_id=MODEL,
238
- base_url=BASE_URL,
239
- api_key=_get_next_api_key(),
240
- generation_kwargs={
241
- "temperature": temperature,
242
- "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
243
- },
244
- ),
245
  system_prompt=system_prompt,
246
  output_mappings={"generation": "completion"},
247
  input_mappings={"instruction": "prompt"},
248
  )
249
  else:
250
  response_generator = ChatGeneration(
251
- llm=InferenceEndpointsLLM(
252
- model_id=MODEL,
253
- tokenizer_id=MODEL,
254
- base_url=BASE_URL,
255
- api_key=_get_next_api_key(),
256
- generation_kwargs={
257
- "temperature": temperature,
258
- "max_new_tokens": MAX_NUM_TOKENS,
259
- },
260
- ),
261
  output_mappings={"generation": "completion"},
262
  input_mappings={"conversation": "messages"},
263
  )
@@ -293,7 +266,7 @@ with Pipeline(name="sft") as pipeline:
293
  "max_new_tokens": {MAX_NUM_TOKENS},
294
  "stop_sequences": {_STOP_SEQUENCES}
295
  }},
296
- api_key=os.environ["BASE_URL"],
297
  ),
298
  n_turns={num_turns},
299
  num_rows={num_rows},
1
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
2
 
3
  from synthetic_dataset_generator.constants import (
 
6
  MAX_NUM_TOKENS,
7
  MODEL,
8
  )
9
+ from synthetic_dataset_generator.pipelines.base import _get_llm
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
148
 
149
 
150
  def get_prompt_generator():
151
+ generation_kwargs = {
152
+ "temperature": 0.8,
153
+ "max_new_tokens": MAX_NUM_TOKENS,
154
+ "do_sample": True,
155
+ }
156
  prompt_generator = TextGeneration(
157
+ llm=_get_llm(generation_kwargs=generation_kwargs),
158
  system_prompt=PROMPT_CREATION_PROMPT,
159
  use_system_prompt=True,
160
  )
 
166
  input_mappings = _get_output_mappings(num_turns)
167
  output_mappings = input_mappings.copy()
168
  if num_turns == 1:
169
+ generation_kwargs = {
170
+ "temperature": temperature,
171
+ "do_sample": True,
172
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
173
+ "stop_sequences": _STOP_SEQUENCES,
174
+ }
175
  magpie_generator = Magpie(
176
+ llm=_get_llm(
177
+ generation_kwargs=generation_kwargs,
178
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
179
+ use_magpie_template=True,
180
  ),
181
  n_turns=num_turns,
182
  output_mappings=output_mappings,
183
  only_instruction=True,
184
  )
185
  else:
186
+ generation_kwargs = {
187
+ "temperature": temperature,
188
+ "do_sample": True,
189
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
190
+ "stop_sequences": _STOP_SEQUENCES,
191
+ }
192
  magpie_generator = Magpie(
193
+ llm=_get_llm(
194
+ generation_kwargs=generation_kwargs,
195
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
196
+ use_magpie_template=True,
197
  ),
198
  end_with_user=True,
199
  n_turns=num_turns,
 
204
 
205
 
206
  def get_prompt_rewriter():
207
+ generation_kwargs = {
208
+ "temperature": 1,
209
+ }
210
+ prompt_rewriter = TextGeneration(llm=_get_llm(generation_kwargs=generation_kwargs))
211
  prompt_rewriter.load()
212
  return prompt_rewriter
213
 
214
 
215
  def get_response_generator(system_prompt, num_turns, temperature, is_sample):
216
  if num_turns == 1:
217
+ generation_kwargs = {
218
+ "temperature": temperature,
219
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
220
+ }
221
  response_generator = TextGeneration(
222
+ llm=_get_llm(generation_kwargs=generation_kwargs),
223
  system_prompt=system_prompt,
224
  output_mappings={"generation": "completion"},
225
  input_mappings={"instruction": "prompt"},
226
  )
227
  else:
228
+ generation_kwargs = {
229
+ "temperature": temperature,
230
+ "max_new_tokens": MAX_NUM_TOKENS,
231
+ }
232
  response_generator = ChatGeneration(
233
+ llm=_get_llm(generation_kwargs=generation_kwargs),
234
  output_mappings={"generation": "completion"},
235
  input_mappings={"conversation": "messages"},
236
  )
 
266
  "max_new_tokens": {MAX_NUM_TOKENS},
267
  "stop_sequences": {_STOP_SEQUENCES}
268
  }},
269
+ api_key=os.environ["API_KEY"],
270
  ),
271
  n_turns={num_turns},
272
  num_rows={num_rows},
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,7 +1,6 @@
1
  import random
2
  from typing import List
3
 
4
- from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
5
  from distilabel.steps.tasks import (
6
  GenerateTextClassificationData,
7
  TextClassification,
@@ -9,8 +8,12 @@ from distilabel.steps.tasks import (
9
  )
10
  from pydantic import BaseModel, Field
11
 
12
- from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
13
- from synthetic_dataset_generator.pipelines.base import _get_next_api_key
14
  from synthetic_dataset_generator.utils import get_preprocess_labels
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
@@ -69,23 +72,10 @@ def get_prompt_generator():
69
  "temperature": 0.8,
70
  "max_new_tokens": MAX_NUM_TOKENS,
71
  }
72
- if BASE_URL:
73
- llm = OpenAILLM(
74
- model=MODEL,
75
- base_url=BASE_URL,
76
- api_key=_get_next_api_key(),
77
- structured_output=structured_output,
78
- generation_kwargs=generation_kwargs,
79
- )
80
- else:
81
- generation_kwargs["do_sample"] = True
82
- llm = InferenceEndpointsLLM(
83
- api_key=_get_next_api_key(),
84
- model_id=MODEL,
85
- base_url=BASE_URL,
86
- structured_output=structured_output,
87
- generation_kwargs=generation_kwargs,
88
- )
89
 
90
  prompt_generator = TextGeneration(
91
  llm=llm,
@@ -103,21 +93,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
103
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
104
  "top_p": 0.95,
105
  }
106
- if BASE_URL:
107
- llm = OpenAILLM(
108
- model=MODEL,
109
- base_url=BASE_URL,
110
- api_key=_get_next_api_key(),
111
- generation_kwargs=generation_kwargs,
112
- )
113
- else:
114
- generation_kwargs["do_sample"] = True
115
- llm = InferenceEndpointsLLM(
116
- model_id=MODEL,
117
- base_url=BASE_URL,
118
- api_key=_get_next_api_key(),
119
- generation_kwargs=generation_kwargs,
120
- )
121
 
122
  textcat_generator = GenerateTextClassificationData(
123
  llm=llm,
@@ -134,21 +110,7 @@ def get_labeller_generator(system_prompt, labels, multi_label):
134
  "temperature": 0.01,
135
  "max_new_tokens": MAX_NUM_TOKENS,
136
  }
137
-
138
- if BASE_URL:
139
- llm = OpenAILLM(
140
- model=MODEL,
141
- base_url=BASE_URL,
142
- api_key=_get_next_api_key(),
143
- generation_kwargs=generation_kwargs,
144
- )
145
- else:
146
- llm = InferenceEndpointsLLM(
147
- model_id=MODEL,
148
- base_url=BASE_URL,
149
- api_key=_get_next_api_key(),
150
- generation_kwargs=generation_kwargs,
151
- )
152
 
153
  labeller_generator = TextClassification(
154
  llm=llm,
 
1
  import random
2
  from typing import List
3
 
 
4
  from distilabel.steps.tasks import (
5
  GenerateTextClassificationData,
6
  TextClassification,
 
8
  )
9
  from pydantic import BaseModel, Field
10
 
11
+ from synthetic_dataset_generator.constants import (
12
+ BASE_URL,
13
+ MAX_NUM_TOKENS,
14
+ MODEL,
15
+ )
16
+ from synthetic_dataset_generator.pipelines.base import _get_llm
17
  from synthetic_dataset_generator.utils import get_preprocess_labels
18
 
19
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
 
72
  "temperature": 0.8,
73
  "max_new_tokens": MAX_NUM_TOKENS,
74
  }
75
+ llm = _get_llm(
76
+ structured_output=structured_output,
77
+ generation_kwargs=generation_kwargs,
78
+ )
79
 
80
  prompt_generator = TextGeneration(
81
  llm=llm,
 
93
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
94
  "top_p": 0.95,
95
  }
96
+ llm = _get_llm(generation_kwargs=generation_kwargs)
97
 
98
  textcat_generator = GenerateTextClassificationData(
99
  llm=llm,
 
110
  "temperature": 0.01,
111
  "max_new_tokens": MAX_NUM_TOKENS,
112
  }
113
+ llm = _get_llm(generation_kwargs=generation_kwargs)
114
 
115
  labeller_generator = TextClassification(
116
  llm=llm,