Commit f5ab4cb
Parent(s): b7c81a3
update deployment with API providers

Files changed:
- README.md (+8 -5)
- examples/enforce_mapgie_template copy.py (+0 -9)
- examples/{ollama_local.py → hf-serverless_deployment.py} (+2 -2)
- examples/ollama_deployment.py (+17 -0)
- examples/{openai_local.py → openai_deployment.py} (+2 -1)
- examples/tgi_or_hf_dedicated.py (+14 -0)
- pdm.lock (+51 -16)
- pyproject.toml (+1 -1)
- src/synthetic_dataset_generator/apps/chat.py (+1 -1)
- src/synthetic_dataset_generator/apps/textcat.py (+32 -17)
- src/synthetic_dataset_generator/constants.py (+50 -25)
- src/synthetic_dataset_generator/pipelines/base.py (+80 -1)
- src/synthetic_dataset_generator/pipelines/chat.py (+40 -67)
- src/synthetic_dataset_generator/pipelines/textcat.py (+12 -50)
README.md
CHANGED
@@ -76,21 +76,24 @@ launch()
 
 - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
 
-
+You can set the following environment variables to customize the generation process.
 
 - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
 - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
 - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
 
-Optionally, you can use different
+Optionally, you can use different API providers and models.
 
-- `
-- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+- `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
+- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/v1/`.
+- `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. a TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
 
 SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints, and you can set the following environment variables to use it with models other than Llama3 and Qwen2.
 
-- `
+- `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
+- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
 
 Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
 
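Taken together, the variables documented in this README section can be combined in a single launch script. A minimal sketch of such a configuration (the model id, key, and base URL below are illustrative placeholders, not project defaults, and only one base URL should be set at a time):

# Illustrative configuration only; values are placeholders.
import os

from synthetic_dataset_generator import launch

assert os.getenv("OPENAI_API_KEY")  # assumed to be exported beforehand
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"  # pick exactly one *_BASE_URL
os.environ["MODEL"] = "gpt-4o"
os.environ["API_KEY"] = os.environ["OPENAI_API_KEY"]  # otherwise falls back to HF_TOKEN
os.environ["MAX_NUM_ROWS"] = "500"

launch()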
examples/enforce_mapgie_template copy.py
DELETED
@@ -1,9 +0,0 @@
-# pip install synthetic-dataset-generator
-import os
-
-from synthetic_dataset_generator import launch
-
-os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "my_custom_template"
-os.environ["MODEL"] = "google/gemma-2-9b-it"
-
-launch()
examples/{ollama_local.py → hf-serverless_deployment.py}
RENAMED
@@ -4,7 +4,7 @@ import os
 from synthetic_dataset_generator import launch
 
 assert os.getenv("HF_TOKEN")  # push the data to huggingface
-os.environ["
-os.environ["
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use instruct model
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model
 
 launch()
examples/ollama_deployment.py
ADDED
@@ -0,0 +1,17 @@
+# pip install synthetic-dataset-generator
+# ollama serve
+# ollama run llama3.1:8b-instruct-q8_0
+import os
+
+from synthetic_dataset_generator import launch
+
+assert os.getenv("HF_TOKEN")  # push the data to huggingface
+os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
+os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0"
+os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "5"
+os.environ["MAX_NUM_TOKENS"] = "2048"
+
+launch()
examples/{openai_local.py → openai_deployment.py}
RENAMED
@@ -4,8 +4,9 @@ import os
 from synthetic_dataset_generator import launch
 
 assert os.getenv("HF_TOKEN")  # push the data to huggingface
-os.environ["
+os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"
 os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")
 os.environ["MODEL"] = "gpt-4o"
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None  # chat data not supported with OpenAI
 
 launch()
examples/tgi_or_hf_dedicated.py
ADDED
@@ -0,0 +1,14 @@
+# pip install synthetic-dataset-generator
+import os
+
+from synthetic_dataset_generator import launch
+
+assert os.getenv("HF_TOKEN")  # push the data to huggingface
+os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/"
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
+os.environ["TOKENIZER_ID"] = (
+    "meta-llama/Llama-3.1-8B-Instruct"  # tokenizer for model hosted on endpoint
+)
+os.environ["MODEL"] = None  # model is linked to endpoint
+
+launch()
pdm.lock
CHANGED
@@ -5,7 +5,7 @@
 groups = ["default"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:
+content_hash = "sha256:e95140895657d62ad438ff1815ddf1798abbb342ddd2649ae462620b8b3f5350"
 
 [[metadata.targets]]
 requires_python = ">=3.10,<3.13"
@@ -491,8 +491,11 @@ files = [
 
 [[package]]
 name = "distilabel"
-version = "1.4.1"
+version = "1.5.0"
 requires_python = ">=3.9"
+git = "https://github.com/argilla-io/distilabel.git"
+ref = "feat/add-magpie-support-llama-cpp-ollama"
+revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
 summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
 groups = ["default"]
 dependencies = [
@@ -512,30 +515,30 @@ dependencies = [
     "typer>=0.9.0",
     "universal-pathlib>=0.2.2",
 ]
-files = [
-    {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
-    {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
-]
 
 [[package]]
 name = "distilabel"
-version = "1.4.1"
-extras = ["argilla", "hf-inference-endpoints", "instructor", "outlines"]
+version = "1.5.0"
+extras = ["argilla", "hf-inference-endpoints", "hf-transformers", "instructor", "llama-cpp", "ollama", "openai", "outlines"]
 requires_python = ">=3.9"
+git = "https://github.com/argilla-io/distilabel.git"
+ref = "feat/add-magpie-support-llama-cpp-ollama"
+revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
 summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
 groups = ["default"]
 dependencies = [
     "argilla>=2.0.0",
-    "distilabel
+    "distilabel @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
     "huggingface-hub>=0.22.0",
     "instructor>=1.2.3",
     "ipython",
+    "llama-cpp-python>=0.2.0",
     "numba>=0.54.0",
+    "ollama>=0.1.7",
+    "openai>=1.0.0",
     "outlines>=0.0.40",
-
-
-    {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
-    {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
+    "torch>=2.0.0",
+    "transformers>=4.34.1",
 ]
 
 [[package]]
@@ -824,7 +827,7 @@ files = [
 
 [[package]]
 name = "httpx"
-version = "0.
+version = "0.27.2"
 requires_python = ">=3.8"
 summary = "The next generation HTTP client."
 groups = ["default"]
@@ -833,10 +836,11 @@ dependencies = [
     "certifi",
     "httpcore==1.*",
     "idna",
+    "sniffio",
 ]
 files = [
-    {file = "httpx-0.
-    {file = "httpx-0.
+    {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
+    {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
 ]
 
 [[package]]
@@ -1068,6 +1072,22 @@ files = [
     {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
 ]
 
+[[package]]
+name = "llama-cpp-python"
+version = "0.3.5"
+requires_python = ">=3.8"
+summary = "Python bindings for the llama.cpp library"
+groups = ["default"]
+dependencies = [
+    "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
+    "numpy>=1.20.0",
+    "typing-extensions>=4.5.0",
+]
+files = [
+    {file = "llama_cpp_python-0.3.5.tar.gz", hash = "sha256:f5ce47499d53d3973e28ca5bdaf2dfe820163fa3fb67e3050f98e2e9b58d2cf6"},
+]
+
 [[package]]
 name = "llvmlite"
 version = "0.43.0"
@@ -1538,6 +1558,21 @@ files = [
     {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
 ]
 
+[[package]]
+name = "ollama"
+version = "0.4.4"
+requires_python = "<4.0,>=3.8"
+summary = "The official Python client for Ollama."
+groups = ["default"]
+dependencies = [
+    "httpx<0.28.0,>=0.27.0",
+    "pydantic<3.0.0,>=2.9.0",
+]
+files = [
+    {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"},
+    {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"},
+]
+
 [[package]]
 name = "openai"
 version = "1.57.4"
pyproject.toml
CHANGED
@@ -18,7 +18,7 @@ readme = "README.md"
 license = {text = "Apache 2"}
 
 dependencies = [
-    "distilabel[hf-inference-endpoints,
+    "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines] @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
     "gradio[oauth]>=5.4.0,<6.0.0",
     "transformers>=4.44.2,<5.0.0",
     "sentence-transformers>=3.2.0,<4.0.0",
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -121,7 +121,7 @@ def generate_dataset(
         {
             "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
         }
-        for i in range(int(num_rows /
+        for i in range(int(num_rows / 100))
     ]
     batch = list(prompt_rewriter.process(inputs=inputs))
     prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
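The completed divisor means one rewritten system prompt is requested per 100 rows, with the original system prompt appended afterwards. A rough arithmetic sketch (numbers are illustrative, not from the commit):

# Illustrative arithmetic for the rewrite count after this change.
num_rows = 1000                      # hypothetical request size
num_rewrites = int(num_rows / 100)   # -> 10 rewritten system prompts
total_prompts = num_rewrites + 1     # -> 11, including the original system prompt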
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -178,26 +178,41 @@ def generate_dataset(
 
     dataframe = pd.DataFrame(distiset_results)
     if multi_label:
-
-
-
-
-
-
-
-
-
-
-
-
+
+        def _validate_labels(x):
+            if isinstance(x, str):  # single label
+                return [x.lower().strip()]
+            elif isinstance(x, list):  # multiple labels
+                return [
+                    label.lower().strip()
+                    for label in x
+                    if label.lower().strip() in labels
+                ]
+            else:
+                return [random.choice(labels)]
+
+        dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
         dataframe = dataframe[dataframe["labels"].notna()]
     else:
+
+        def _validate_labels(x):
+            if isinstance(x, str) and x.lower().strip() in labels:
+                return x.lower().strip()
+            elif isinstance(x, list):
+                options = [
+                    label.lower().strip()
+                    for label in x
+                    if isinstance(label, str) and label.lower().strip() in labels
+                ]
+                if options:
+                    return random.choice(options)
+                else:
+                    return random.choice(labels)
+            else:
+                return random.choice(labels)
+
         dataframe = dataframe.rename(columns={"labels": "label"})
-        dataframe["label"] = dataframe["label"].apply(
-            lambda x: x.lower().strip()
-            if x and x.lower().strip() in labels
-            else random.choice(labels)
-        )
+        dataframe["label"] = dataframe["label"].apply(_validate_labels)
         dataframe = dataframe[dataframe["text"].notna()]
 
     progress(1.0, desc="Dataset created")
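To see what the new validation does to raw model output, here is a standalone sketch of the multi-label branch with a made-up label set; it is not part of the commit, only an illustration of the behaviour added above.

import random

labels = ["positive", "negative", "neutral"]  # hypothetical label set


def _validate_labels(x):  # mirrors the multi-label branch added above
    if isinstance(x, str):  # a single label is wrapped in a list
        return [x.lower().strip()]
    elif isinstance(x, list):  # unknown labels are dropped
        return [label.lower().strip() for label in x if label.lower().strip() in labels]
    else:  # missing values fall back to a random known label
        return [random.choice(labels)]


print(_validate_labels(" Positive "))           # ['positive']
print(_validate_labels(["Negative", "bogus"]))  # ['negative']
print(_validate_labels(None))                   # e.g. ['neutral']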
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -7,39 +7,63 @@ import argilla as rg
 TEXTCAT_TASK = "text_classification"
 SFT_TASK = "supervised_fine_tuning"
 
-#
+# Inference
+MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
+MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
+DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
+
+# Models
+MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
+TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
+OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
+HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
+if HUGGINGFACE_BASE_URL and MODEL:
+    raise ValueError(
+        "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
+    )
+if OPENAI_BASE_URL or OLLAMA_BASE_URL:
+    if not MODEL:
+        raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
+
+# Check if multiple base URLs are provided
+base_urls = [
+    url for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL] if url
+]
+if len(base_urls) > 1:
+    raise ValueError(
+        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
+    )
+BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL
+
+# API Keys
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
     raise ValueError(
         "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
     )
 
-# Inference
-MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
-MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
-DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
-MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
-BASE_URL = os.getenv("BASE_URL", default=None)
-
 _API_KEY = os.getenv("API_KEY")
-
-
-
-
-
-]
+API_KEYS = (
+    [_API_KEY]
+    if _API_KEY
+    else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
+)
 API_KEYS = [token for token in API_KEYS if token]
 
 # Determine if SFT is available
 SFT_AVAILABLE = False
 llama_options = ["llama3", "llama-3", "llama 3"]
 qwen_options = ["qwen2", "qwen-2", "qwen 2"]
-
+
+if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
     SFT_AVAILABLE = True
-passed_pre_query_template
-    if passed_pre_query_template.lower() in llama_options:
+    if passed_pre_query_template in llama_options:
         MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
-    elif passed_pre_query_template
+    elif passed_pre_query_template in qwen_options:
         MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
     else:
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
@@ -54,12 +78,12 @@ elif MODEL.lower() in qwen_options or any(
     SFT_AVAILABLE = True
     MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
 
-if
+if OPENAI_BASE_URL:
     SFT_AVAILABLE = False
 
 if not SFT_AVAILABLE:
     warnings.warn(
-
+        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`."
     )
     MAGPIE_PRE_QUERY_TEMPLATE = None
 
@@ -67,11 +91,12 @@ if not SFT_AVAILABLE:
 STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
 
 # Argilla
-ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
-
-
-
-
+ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
+    "ARGILLA_API_URL_SDG_REVIEWER"
+)
+ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
+    "ARGILLA_API_KEY_SDG_REVIEWER"
+)
 
 if not ARGILLA_API_URL or not ARGILLA_API_KEY:
     warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
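Restated outside the module, the intent of the new provider checks is: at most one base URL may be set, `HUGGINGFACE_BASE_URL` excludes `MODEL`, and OpenAI or Ollama endpoints require a `MODEL`. A minimal standalone sketch of that rule (illustrative only, not the project's code):

def resolve_base_url(openai_url=None, ollama_url=None, hf_url=None, model=None):
    # Sketch mirroring the validation added to constants.py in this commit.
    if hf_url and model:
        raise ValueError("HUGGINGFACE_BASE_URL and MODEL cannot be set at the same time")
    if (openai_url or ollama_url) and not model:
        raise ValueError("MODEL is required when using an OpenAI or Ollama base URL")
    base_urls = [url for url in (openai_url, ollama_url, hf_url) if url]
    if len(base_urls) > 1:
        raise ValueError(f"Multiple base URLs provided: {', '.join(base_urls)}")
    return base_urls[0] if base_urls else None


assert resolve_base_url(ollama_url="http://127.0.0.1:11434/", model="llama3.1") == "http://127.0.0.1:11434/"
assert resolve_base_url() is None  # serverless Inference Endpoints, selected via MODEL only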
src/synthetic_dataset_generator/pipelines/base.py
CHANGED
@@ -1,4 +1,15 @@
-
+import gradio as gr
+from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM
+
+from synthetic_dataset_generator.constants import (
+    API_KEYS,
+    HUGGINGFACE_BASE_URL,
+    MAGPIE_PRE_QUERY_TEMPLATE,
+    MODEL,
+    OLLAMA_BASE_URL,
+    OPENAI_BASE_URL,
+    TOKENIZER_ID,
+)
 
 TOKEN_INDEX = 0
 
@@ -8,3 +19,71 @@ def _get_next_api_key():
     api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
     TOKEN_INDEX += 1
     return api_key
+
+
+def _get_llm(use_magpie_template=False, **kwargs):
+    if OPENAI_BASE_URL:
+        llm = OpenAILLM(
+            model=MODEL,
+            base_url=OPENAI_BASE_URL,
+            api_key=_get_next_api_key(),
+            **kwargs,
+        )
+        if "generation_kwargs" in kwargs:
+            if "stop_sequences" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
+                    "stop_sequences"
+                ]
+                del kwargs["generation_kwargs"]["stop_sequences"]
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+    elif OLLAMA_BASE_URL:
+        if "generation_kwargs" in kwargs:
+            if "max_new_tokens" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["num_predict"] = kwargs[
+                    "generation_kwargs"
+                ]["max_new_tokens"]
+                del kwargs["generation_kwargs"]["max_new_tokens"]
+            if "stop_sequences" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
+                    "stop_sequences"
+                ]
+                del kwargs["generation_kwargs"]["stop_sequences"]
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+            options = kwargs["generation_kwargs"]
+            del kwargs["generation_kwargs"]
+            kwargs["generation_kwargs"] = {}
+            kwargs["generation_kwargs"]["options"] = options
+        llm = OllamaLLM(
+            model=MODEL,
+            host=OLLAMA_BASE_URL,
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            **kwargs,
+        )
+    elif HUGGINGFACE_BASE_URL:
+        kwargs["generation_kwargs"]["do_sample"] = True
+        llm = InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            base_url=HUGGINGFACE_BASE_URL,
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            **kwargs,
+        )
+    else:
+        llm = InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            model_id=MODEL,
+            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
+            **kwargs,
+        )
+
+    return llm
+
+
+try:
+    llm = _get_llm()
+    llm.load()
+    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
+except Exception as e:
+    gr.Error(f"Error loading {llm.__class__.__name__}: {e}")
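For context, the pipeline modules later in this commit call the new helper roughly as follows; this is an assumed usage sketch based on those diffs, with placeholder kwargs. `_get_llm` translates the generic generation kwargs per provider, e.g. `max_new_tokens` becomes Ollama's `num_predict` and `stop_sequences` becomes `stop` for the OpenAI and Ollama clients.

# Assumed usage pattern, mirroring the chat/textcat pipeline modules in this commit.
from synthetic_dataset_generator.pipelines.base import _get_llm

generation_kwargs = {
    "temperature": 0.8,                # placeholder values
    "max_new_tokens": 512,             # mapped to num_predict for Ollama
    "stop_sequences": ["<|eot_id|>"],  # mapped to stop for OpenAI/Ollama
}
llm = _get_llm(generation_kwargs=generation_kwargs)
llm.load()  # distilabel LLMs are loaded before use, as in the try block above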
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED
@@ -1,4 +1,3 @@
-from distilabel.llms import InferenceEndpointsLLM
 from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
 
 from synthetic_dataset_generator.constants import (
@@ -7,7 +6,7 @@ from synthetic_dataset_generator.constants import (
     MAX_NUM_TOKENS,
     MODEL,
 )
-from synthetic_dataset_generator.pipelines.base import
+from synthetic_dataset_generator.pipelines.base import _get_llm
 
 INFORMATION_SEEKING_PROMPT = (
     "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -149,18 +148,13 @@ def _get_output_mappings(num_turns):
 
 
 def get_prompt_generator():
+    generation_kwargs = {
+        "temperature": 0.8,
+        "max_new_tokens": MAX_NUM_TOKENS,
+        "do_sample": True,
+    }
     prompt_generator = TextGeneration(
-        llm=
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            base_url=BASE_URL,
-            generation_kwargs={
-                "temperature": 0.8,
-                "max_new_tokens": MAX_NUM_TOKENS,
-                "do_sample": True,
-            },
-        ),
+        llm=_get_llm(generation_kwargs=generation_kwargs),
         system_prompt=PROMPT_CREATION_PROMPT,
         use_system_prompt=True,
     )
@@ -172,38 +166,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
+        generation_kwargs = {
+            "temperature": temperature,
+            "do_sample": True,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
+            "stop_sequences": _STOP_SEQUENCES,
+        }
         magpie_generator = Magpie(
-            llm=
-
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
+            llm=_get_llm(
+                generation_kwargs=generation_kwargs,
                 magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
-
-                    "temperature": temperature,
-                    "do_sample": True,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
+                use_magpie_template=True,
             ),
             n_turns=num_turns,
             output_mappings=output_mappings,
             only_instruction=True,
         )
     else:
+        generation_kwargs = {
+            "temperature": temperature,
+            "do_sample": True,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
+            "stop_sequences": _STOP_SEQUENCES,
+        }
         magpie_generator = Magpie(
-            llm=
-
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
+            llm=_get_llm(
+                generation_kwargs=generation_kwargs,
                 magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
-
-                    "temperature": temperature,
-                    "do_sample": True,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
+                use_magpie_template=True,
             ),
             end_with_user=True,
             n_turns=num_turns,
@@ -214,50 +204,33 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
 
 
 def get_prompt_rewriter():
-
-
-
-
-        base_url=BASE_URL,
-        api_key=_get_next_api_key(),
-        generation_kwargs={
-            "temperature": 1,
-        },
-    ),
-    )
+    generation_kwargs = {
+        "temperature": 1,
+    }
+    prompt_rewriter = TextGeneration(llm=_get_llm(generation_kwargs=generation_kwargs))
     prompt_rewriter.load()
     return prompt_rewriter
 
 
 def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     if num_turns == 1:
+        generation_kwargs = {
+            "temperature": temperature,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
+        }
         response_generator = TextGeneration(
-            llm=
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
-                generation_kwargs={
-                    "temperature": temperature,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
-                },
-            ),
+            llm=_get_llm(generation_kwargs=generation_kwargs),
             system_prompt=system_prompt,
             output_mappings={"generation": "completion"},
            input_mappings={"instruction": "prompt"},
         )
     else:
+        generation_kwargs = {
+            "temperature": temperature,
+            "max_new_tokens": MAX_NUM_TOKENS,
+        }
         response_generator = ChatGeneration(
-            llm=
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
-                generation_kwargs={
-                    "temperature": temperature,
-                    "max_new_tokens": MAX_NUM_TOKENS,
-                },
-            ),
+            llm=_get_llm(generation_kwargs=generation_kwargs),
             output_mappings={"generation": "completion"},
             input_mappings={"conversation": "messages"},
         )
@@ -293,7 +266,7 @@ with Pipeline(name="sft") as pipeline:
             "max_new_tokens": {MAX_NUM_TOKENS},
             "stop_sequences": {_STOP_SEQUENCES}
         }},
-        api_key=os.environ["
+        api_key=os.environ["API_KEY"],
     ),
     n_turns={num_turns},
     num_rows={num_rows},
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,7 +1,6 @@
 import random
 from typing import List
 
-from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
 from distilabel.steps.tasks import (
     GenerateTextClassificationData,
     TextClassification,
@@ -9,8 +8,12 @@ from distilabel.steps.tasks import (
 )
 from pydantic import BaseModel, Field
 
-from synthetic_dataset_generator.constants import
-
+from synthetic_dataset_generator.constants import (
+    BASE_URL,
+    MAX_NUM_TOKENS,
+    MODEL,
+)
+from synthetic_dataset_generator.pipelines.base import _get_llm
 from synthetic_dataset_generator.utils import get_preprocess_labels
 
 PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
@@ -69,23 +72,10 @@ def get_prompt_generator():
         "temperature": 0.8,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-
-
-
-
-            api_key=_get_next_api_key(),
-            structured_output=structured_output,
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        generation_kwargs["do_sample"] = True
-        llm = InferenceEndpointsLLM(
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            base_url=BASE_URL,
-            structured_output=structured_output,
-            generation_kwargs=generation_kwargs,
-        )
+    llm = _get_llm(
+        structured_output=structured_output,
+        generation_kwargs=generation_kwargs,
+    )
 
     prompt_generator = TextGeneration(
         llm=llm,
@@ -103,21 +93,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
         "top_p": 0.95,
     }
-
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        generation_kwargs["do_sample"] = True
-        llm = InferenceEndpointsLLM(
-            model_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
+    llm = _get_llm(generation_kwargs=generation_kwargs)
 
     textcat_generator = GenerateTextClassificationData(
         llm=llm,
@@ -134,21 +110,7 @@ def get_labeller_generator(system_prompt, labels, multi_label):
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-
-    if BASE_URL:
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        llm = InferenceEndpointsLLM(
-            model_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
+    llm = _get_llm(generation_kwargs=generation_kwargs)
 
     labeller_generator = TextClassification(
         llm=llm,