Commit
·
1f8e434
1
Parent(s):
a7d1c2d
Update load_data.py
Browse files- load_data.py +15 -3
load_data.py
CHANGED
@@ -2,21 +2,33 @@ import sys
|
|
2 |
import time
|
3 |
import os
|
4 |
|
5 |
-
import argilla as rg
|
6 |
import pandas as pd
|
7 |
import requests
|
8 |
from datasets import load_dataset, concatenate_datasets
|
9 |
|
|
|
10 |
from argilla.listeners import listener
|
11 |
|
|
|
|
|
|
|
12 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
|
13 |
SOURCE_DATASET = "LEL-A/translated_german_alpaca"
|
|
|
|
|
14 |
RG_DATASET_NAME = "translated-german-alpaca"
|
|
|
|
|
15 |
HUB_DATASET_NAME = os.environ.get('HUB_DATASET_NAME', f"{SOURCE_DATASET}_validation")
|
16 |
|
|
|
|
|
|
|
17 |
@listener(
|
18 |
dataset=RG_DATASET_NAME,
|
19 |
-
query="status:Validated",
|
20 |
execution_interval_in_seconds=1200, # interval to check the execution of `save_validated_to_hub`
|
21 |
)
|
22 |
def save_validated_to_hub(records, ctx):
|
@@ -60,7 +72,7 @@ class LoadDatasets:
|
|
60 |
records = rg.DatasetForTextClassification.from_datasets(dataset)
|
61 |
|
62 |
settings = rg.TextClassificationSettings(
|
63 |
-
label_schema=
|
64 |
)
|
65 |
|
66 |
print(f"Configuring dataset: {RG_DATASET_NAME}")
|
|
|
2 |
import time
|
3 |
import os
|
4 |
|
|
|
5 |
import pandas as pd
|
6 |
import requests
|
7 |
from datasets import load_dataset, concatenate_datasets
|
8 |
|
9 |
+
import argilla as rg
|
10 |
from argilla.listeners import listener
|
11 |
|
12 |
+
### Configuration section ###
|
13 |
+
|
14 |
+
# Needed for pushing the validated data to HUB_DATASET_NAME
|
15 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
16 |
+
|
17 |
+
# The source dataset from which the translated Alpaca examples are read
|
18 |
SOURCE_DATASET = "LEL-A/translated_german_alpaca"
|
19 |
+
|
20 |
+
# The name of the dataset in Argilla
|
21 |
RG_DATASET_NAME = "translated-german-alpaca"
|
22 |
+
|
23 |
+
# The name of the Hub dataset to which the validated records are pushed every 20 minutes to keep it in sync
|
24 |
HUB_DATASET_NAME = os.environ.get('HUB_DATASET_NAME', f"{SOURCE_DATASET}_validation")
|
25 |
|
26 |
+
# The labels for the tasks (they can be extended if needed)
|
27 |
+
LABELS = ["BAD INSTRUCTION", "BAD INPUT", "BAD OUTPUT", "INAPPROPRIATE", "BIASED", "ALL GOOD"]
|
28 |
+
|
29 |
@listener(
|
30 |
dataset=RG_DATASET_NAME,
|
31 |
+
query="status:Validated",
|
32 |
execution_interval_in_seconds=1200, # interval to check the execution of `save_validated_to_hub`
|
33 |
)
|
34 |
def save_validated_to_hub(records, ctx):
|
|
|
72 |
records = rg.DatasetForTextClassification.from_datasets(dataset)
|
73 |
|
74 |
settings = rg.TextClassificationSettings(
|
75 |
+
label_schema=LABELS
|
76 |
)
|
77 |
|
78 |
print(f"Configuring dataset: {RG_DATASET_NAME}")
|