Upload folder using huggingface_hub
- dataset.py +7 -3
- metric.py +6 -1
- metrics.py +72 -47
- standard.py +3 -1
- task.py +11 -1
- version.py +1 -1
dataset.py
CHANGED
@@ -10,6 +10,7 @@ from .catalog import __file__ as _
 from .collections import __file__ as _
 from .collections_operators import __file__ as _
 from .dataclass import __file__ as _
+from .dataset_utils import __file__ as _
 from .dataset_utils import get_dataset_artifact
 from .deprecation_utils import __file__ as _
 from .dialog_operators import __file__ as _
@@ -19,11 +20,13 @@ from .file_utils import __file__ as _
 from .formats import __file__ as _
 from .fusion import __file__ as _
 from .generator_utils import __file__ as _
+from .hf_utils import __file__ as _
 from .hf_utils import verify_versions_compatibility
 from .inference import __file__ as _
 from .instructions import __file__ as _
 from .llm_as_judge import __file__ as _
 from .loaders import __file__ as _
+from .logging_utils import __file__ as _
 from .logging_utils import get_logger
 from .metric import __file__ as _
 from .metric_utils import __file__ as _
@@ -37,6 +40,7 @@ from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
 from .schema import __file__ as _
+from .settings_utils import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
@@ -50,6 +54,7 @@ from .task import __file__ as _
 from .templates import __file__ as _
 from .text_utils import __file__ as _
 from .type_utils import __file__ as _
+from .utils import __file__ as _
 from .utils import is_package_installed
 from .validate import __file__ as _
 from .version import __file__ as _
@@ -70,9 +75,8 @@ class Dataset(datasets.GeneratorBasedBuilder):
         if is_package_installed("unitxt"):
             verify_versions_compatibility("dataset", self.VERSION)
 
-            from unitxt.dataset_utils import (
-                get_dataset_artifact as get_dataset_artifact_installed
-            )
+            from unitxt.dataset_utils import \
+                get_dataset_artifact as get_dataset_artifact_installed
 
             logger.info("Loading with installed unitxt library...")
             dataset = get_dataset_artifact_installed(self.config.name)
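
Note: the dataset.py loader keeps preferring an installed unitxt package over the code bundled with the Hub dataset. A minimal sketch of that fallback pattern, assuming only the helper names that appear in the diff (is_package_installed, verify_versions_compatibility, get_dataset_artifact); the surrounding function is illustrative, not the actual builder code:

# Illustrative sketch only -- a simplified stand-in for the builder logic in dataset.py.
import importlib.util


def is_package_installed(name: str) -> bool:
    # True when `name` is importable from the current environment.
    return importlib.util.find_spec(name) is not None


def load_dataset_artifact(config_name: str, version: str, bundled_get_dataset_artifact):
    if is_package_installed("unitxt"):
        # Check that the installed library matches the version of this loader script,
        # then delegate to it instead of the bundled copy of the code.
        from unitxt.hf_utils import verify_versions_compatibility
        verify_versions_compatibility("dataset", version)

        from unitxt.dataset_utils import \
            get_dataset_artifact as get_dataset_artifact_installed

        return get_dataset_artifact_installed(config_name)

    # Otherwise fall back to the modules shipped alongside the loader (the relative
    # `from .dataset_utils import get_dataset_artifact` in the diff), passed in here
    # as a plain callable to keep the sketch self-contained.
    return bundled_get_dataset_artifact(config_name)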
metric.py
CHANGED
@@ -19,13 +19,16 @@ from .file_utils import __file__ as _
 from .formats import __file__ as _
 from .fusion import __file__ as _
 from .generator_utils import __file__ as _
+from .hf_utils import __file__ as _
 from .hf_utils import verify_versions_compatibility
 from .inference import __file__ as _
 from .instructions import __file__ as _
 from .llm_as_judge import __file__ as _
 from .loaders import __file__ as _
 from .logging_utils import __file__ as _
-from .metric_utils import UNITXT_METRIC_SCHEMA
+from .metric_utils import UNITXT_METRIC_SCHEMA
+from .metric_utils import __file__ as _
+from .metric_utils import _compute
 from .metrics import __file__ as _
 from .normalizers import __file__ as _
 from .operator import __file__ as _
@@ -36,6 +39,7 @@ from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
 from .schema import __file__ as _
+from .settings_utils import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
@@ -49,6 +53,7 @@ from .task import __file__ as _
 from .templates import __file__ as _
 from .text_utils import __file__ as _
 from .type_utils import __file__ as _
+from .utils import __file__ as _
 from .utils import is_package_installed
 from .validate import __file__ as _
 from .version import __file__ as _
metrics.py
CHANGED
@@ -29,7 +29,7 @@ from .operators import CopyFields
 from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
-from .type_utils import isoftype, parse_type_string
+from .type_utils import isoftype, parse_type_string
 
 logger = get_logger()
 settings = get_settings()
@@ -1261,17 +1261,28 @@ class F1Micro(F1):
     average = "micro"
 
 
-class F1Binary(F1):
+class F1Binary(GlobalMetric):
     """Calculate f1 for a binary task, using 0.5 as the threshold in the case of float predictions."""
 
     process_single_instances = False
     main_score = "f1_binary"
-    average =
-    pos_classes = {"1", "1.0", "yes", "true"}
+    average = None
     threshold = 0.5
+    prediction_type = "Union[float, int]"
+    _metric = None
+    metric = "f1"
+    single_reference_per_prediction = True
 
-    def
-
+    def prepare(self):
+        super().prepare()
+        self._metric = evaluate.load(self.metric)
+
+    def _validate_reference(self, reference):
+        super()._validate_reference(reference)
+        assert reference[0] in [
+            0,
+            1,
+        ], f"all references of {self.main_score} must by 0 or 1"
 
     def compute(
         self,
@@ -1279,12 +1290,21 @@ class F1Binary(F1):
         predictions: List[str],
         task_data: List[Dict],
     ) -> dict:
-
-
-
-
-
-
+        flattened_int_references = [int(r[0]) for r in references]
+        int_predictions = [int(p > self.threshold) for p in predictions]
+
+        result = self._metric.compute(
+            references=flattened_int_references,
+            predictions=int_predictions,
+            labels=[0, 1],
+            average=self.average,
+        )
+        if isinstance(result[self.metric], numpy.ndarray):
+            return {
+                self.main_score: result[self.metric][1],
+                f"{self.main_score}_neg": result[self.metric][0],
+            }
+        return {self.main_score: result[self.metric]}
 
 
 class RecallBinary(F1Binary):
@@ -1538,7 +1558,7 @@ class KendallTauMetric(GlobalMetric):
     main_score = "kendalltau_b"
     variant = "b"
     process_single_instances = False
-    prediction_type = "
+    prediction_type = "float"
 
     _requirements_list: List[str] = ["scipy"]
 
@@ -1555,8 +1575,6 @@ class KendallTauMetric(GlobalMetric):
     ) -> dict:
         if isinstance(references[0], list):
             references = [reference[0] for reference in references]
-        references = [to_float_or_default(r) for r in references]
-        predictions = [to_float_or_default(p) for p in predictions]
 
         kendall_results = self.kendalltau(references, predictions, variant=self.variant)
         corr = kendall_results.correlation
@@ -1602,7 +1620,7 @@ class RocAuc(GlobalMetric):
     process_single_instances = False
     _requirements_list: List[str] = ["sklearn"]
     single_reference_per_prediction = True
-    prediction_type = "
+    prediction_type = "float"
 
     def prepare(self):
        from sklearn import metrics
@@ -1618,8 +1636,6 @@ class RocAuc(GlobalMetric):
     ) -> dict:
         if isinstance(references[0], list):
             references = [reference[0] for reference in references]
-        references = [to_float_or_default(r) for r in references]
-        predictions = [to_float_or_default(p) for p in predictions]
 
         false_positive_rates, true_positive_rates, _ = self.roc_curve(
             y_true=references, y_score=predictions
@@ -3337,33 +3353,42 @@ class BinaryMaxF1(F1Binary):
     """Calculate the maximal F1 and the decision threshold that achieves it for a binary task with float predictions."""
 
     main_score = "max_f1_binary"
-    prediction_type = str
     single_reference_per_prediction = True
 
     def compute(
         self,
-        references: List[List[
-        predictions: List[List[
+        references: List[List[float]],
+        predictions: List[List[float]],
         task_data: List[Dict],
     ) -> dict:
-        float_predictions = [to_float_or_default(p) for p in predictions]
-
         best_thr = -1
         best_f1 = -1
-
+        best_thr_neg = -1
+        best_f1_neg = -1
+        thrs = {round(fp, 3) for fp in predictions}
         for thr in thrs:
             new_predictions = [
-
-                for float_prediction in
-            ]
-            f1 = super().compute(references, new_predictions, task_data)[
-                self.main_score
+                1.0 if float_prediction >= thr else 0.0
+                for float_prediction in predictions
             ]
+            f1_results = super().compute(references, new_predictions, task_data)
+
+            f1 = f1_results[self.main_score]
             if f1 > best_f1:
                 best_f1 = f1
                 best_thr = thr
 
-
+            f1_neg = f1_results[f"{self.main_score}_neg"]
+            if f1_neg > best_f1_neg:
+                best_f1_neg = f1_neg
+                best_thr_neg = thr
+
+        return {
+            self.main_score: best_f1,
+            "best_thr_maxf1": best_thr,
+            f"{self.main_score}_neg": best_f1_neg,
+            "best_thr_maxf1_neg": best_thr_neg,
+        }
 
 
 class BinaryAccuracy(InstanceMetric):
@@ -3372,20 +3397,25 @@ class BinaryAccuracy(InstanceMetric):
     reduction_map = {"mean": ["accuracy_binary"]}
     main_score = "accuracy_binary"
     ci_scores = ["accuracy_binary"]
-    pos_classes = {"1", "1.0", "yes", "true"}
     threshold = 0.5
 
-    prediction_type = "
+    prediction_type = "Union[float,int]"
     single_reference_per_prediction = True
 
+    def _validate_reference(self, reference):
+        super()._validate_reference(reference)
+        assert reference[0] in [
+            0,
+            1,
+        ], f"all references of {self.main_score} must by 0 or 1"
+
     def compute(
-        self, references: List[
+        self, references: List[float], prediction: float, task_data: List[Dict]
     ) -> dict:
-
-
-        references = ["1"] if references[0].lower() in self.pos_classes else ["0"]
+        prediction = int(prediction > self.threshold)
+        reference = int(references[0])
 
-        result = {self.main_score: float(
+        result = {self.main_score: float(prediction == reference)}
         result["score"] = result[self.main_score]
         result["score_name"] = self.main_score
         return result
@@ -3396,9 +3426,7 @@ class BinaryMaxAccuracy(GlobalMetric):
 
     process_single_instances = False
     main_score = "max_accuracy_binary"
-
-
-    prediction_type = "str"
+    prediction_type = "Union[float,int]"
     single_reference_per_prediction = True
 
     def compute(
@@ -3407,10 +3435,7 @@ class BinaryMaxAccuracy(GlobalMetric):
         predictions: List[str],
         task_data: List[Dict],
     ) -> dict:
-
-        references = [
-            ["1"] if r[0].lower() in self.pos_classes else ["0"] for r in references
-        ]
+        references = [[int(r[0])] for r in references]
 
         # Sticking to the test >= thr, accuracy induced by threshold thr is the number of float predictions
         # that pass the test (are >= thr) and are paired with reference "1" plus the number of float predictions that
@@ -3421,8 +3446,8 @@ class BinaryMaxAccuracy(GlobalMetric):
         # the largest float predictions, to induce the partition into all-failing , none-passing.
 
         fp = [
-            (
-            for i in range(len(
+            (predictions[i], i, -1 if references[i][0] == 1 else +1)
+            for i in range(len(predictions))
         ]
         fp.sort()
         # each triplet above: float-prediction f; f's ordinal position in float_predictions, which is also
@@ -3436,7 +3461,7 @@ class BinaryMaxAccuracy(GlobalMetric):
 
         current_thr = fp[0][0]
         # partition float_predictions into all-passing, none-failing
-        current_acc = sum(r[0] ==
+        current_acc = sum(r[0] == 1 for r in references)
         # number of predictions that thr sends to the reference they are paired with
 
         best_acc = current_acc
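
Note: the reworked F1Binary (and BinaryAccuracy) now take 0/1 references and float or int predictions, threshold the predictions at 0.5, and report both a positive-class and a negative-class score. A minimal standalone sketch of that computation, calling sklearn's f1_score directly instead of the evaluate.load("f1") wrapper the class loads in prepare(); the helper name f1_binary is illustrative only:

# Standalone sketch of the new F1Binary.compute behaviour (not the unitxt class itself).
from sklearn.metrics import f1_score


def f1_binary(references, predictions, threshold=0.5):
    y_true = [int(r[0]) for r in references]            # each reference is [0] or [1]
    y_pred = [int(p > threshold) for p in predictions]  # float scores -> hard 0/1 labels
    # average=None returns one F1 per label, in the order given by `labels`.
    per_class = f1_score(y_true, y_pred, labels=[0, 1], average=None)
    return {"f1_binary": per_class[1], "f1_binary_neg": per_class[0]}


# Example: one positive is missed at the 0.5 threshold.
print(f1_binary([[1], [1], [0], [0]], [0.9, 0.4, 0.2, 0.1]))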
standard.py
CHANGED
@@ -225,7 +225,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
             self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
             self.processing.steps.append(self.augmentor)
 
-        if self.
+        if self.demos_pool_size is not None:
             self.processing.steps.append(
                 CreateDemosPool(
                     from_split=self.demos_taken_from,
@@ -234,6 +234,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                     remove_targets_from_source_split=self.demos_removed_from_data,
                 )
             )
+
+        if self.num_demos > 0:
             if self.sampler is None:
                 if self.card.sampler is None:
                     raise ValueError(
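
Note: the standard.py change splits demo handling in two: the demos pool is built whenever demos_pool_size is set, while the sampler checks and demo sampling only run when num_demos > 0. A rough sketch of the resulting control flow, with field names taken from the diff but the recipe internals and error message simplified:

# Rough, simplified sketch of the gating in BaseRecipe -- not the real class.
def configure_demos(recipe):
    if recipe.demos_pool_size is not None:
        # A pool of candidate demonstrations is always materialized when a pool
        # size is configured, even if no demos end up in the prompts.
        recipe.processing_steps.append(("create_demos_pool", recipe.demos_pool_size))

    if recipe.num_demos > 0:
        # Sampling demos into prompts additionally requires a sampler, taken from
        # the recipe or, failing that, from the card.
        if recipe.sampler is None and recipe.card.sampler is None:
            raise ValueError("num_demos > 0 requires a sampler on the recipe or the card")
        recipe.processing_steps.append(("sample_demos", recipe.num_demos))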
task.py
CHANGED
@@ -3,7 +3,13 @@ from typing import Any, Dict, List, Optional, Union
 from .artifact import fetch_artifact
 from .logging_utils import get_logger
 from .operator import StreamInstanceOperator
-from .type_utils import
+from .type_utils import (
+    get_args,
+    get_origin,
+    isoftype,
+    parse_type_string,
+    verify_required_schema,
+)
 
 
 class Tasker:
@@ -79,6 +85,10 @@ class FormTask(Tasker, StreamInstanceOperator):
                 prediction_type == metric_prediction_type
                 or prediction_type == Any
                 or metric_prediction_type == Any
+                or (
+                    get_origin(metric_prediction_type) is Union
+                    and prediction_type in get_args(metric_prediction_type)
+                )
             ):
                 continue
 
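
Note: the task.py change relaxes the prediction-type check: a task's prediction type is now also accepted when the metric declares a Union that contains it (for example, a float task against a metric whose prediction_type is Union[float, int]). A minimal sketch of that check, using typing.get_origin/get_args in place of the equivalents the diff imports from .type_utils:

# Sketch of the relaxed compatibility check (hypothetical helper, not the unitxt code).
from typing import Any, Union, get_args, get_origin


def prediction_type_is_compatible(prediction_type, metric_prediction_type) -> bool:
    return (
        prediction_type == metric_prediction_type
        or prediction_type == Any
        or metric_prediction_type == Any
        or (
            get_origin(metric_prediction_type) is Union
            and prediction_type in get_args(metric_prediction_type)
        )
    )


assert prediction_type_is_compatible(float, Union[float, int])
assert not prediction_type_is_compatible(str, Union[float, int])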
version.py
CHANGED
@@ -1 +1 @@
-version = "1.
+version = "1.8.0"