# Spaces: Sleeping
import importlib
import os
import shutil
import tempfile
from unittest import TestCase

import pytest
from datasets import DownloadConfig

import evaluate
from evaluate.loading import (
    CachedEvaluationModuleFactory,
    HubEvaluationModuleFactory,
    LocalEvaluationModuleFactory,
    evaluation_module_factory,
)

from .utils import OfflineSimulationMode, offline
# Module identifier handed to HubEvaluationModuleFactory in the script-based Hub test.
SAMPLE_METRIC_IDENTIFIER = "lvwerra/test"

# Name shared by the dummy script's directory and its ``.py`` file.
METRIC_LOADING_SCRIPT_NAME = "__dummy_metric1__"

# Source of a minimal evaluation module; ``metric_loading_script_dir`` writes it
# to disk so the local/cached factories have a real script to load. It must be
# valid Python because the tests import the module the factories produce.
METRIC_LOADING_SCRIPT_CODE = """
import evaluate
from evaluate import EvaluationModuleInfo
from datasets import Features, Value

class __DummyMetric1__(evaluate.EvaluationModule):

    def _info(self):
        return EvaluationModuleInfo(features=Features({"predictions": Value("int"), "references": Value("int")}))

    def _compute(self, predictions, references):
        return {"__dummy_metric1__": sum(int(p == r) for p, r in zip(predictions, references))}
"""
@pytest.fixture
def metric_loading_script_dir(tmp_path):
    """Create a directory containing the dummy metric loading script.

    Registered as a pytest fixture so that ``ModuleFactoryTest.inject_fixtures``
    can request it by name.

    Args:
        tmp_path: Base directory (path-like object supporting ``/`` joining)
            under which the script directory is created.

    Returns:
        str: Path of the new directory, which holds
        ``<METRIC_LOADING_SCRIPT_NAME>.py`` filled with
        ``METRIC_LOADING_SCRIPT_CODE``.
    """
    script_name = METRIC_LOADING_SCRIPT_NAME
    script_dir = tmp_path / script_name
    script_dir.mkdir()
    # The script file carries the same name as its parent directory.
    script_path = script_dir / f"{script_name}.py"
    script_path.write_text(METRIC_LOADING_SCRIPT_CODE)
    return str(script_dir)
class ModuleFactoryTest(TestCase):
    """Tests for the Hub, local and cached evaluation-module factories.

    Every test gets its own temporary modules cache and download cache, both
    created in ``setUp`` and removed again once the test has finished.
    """

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, metric_loading_script_dir):
        """Copy the ``metric_loading_script_dir`` fixture value onto the instance.

        ``unittest.TestCase`` test methods cannot take pytest fixtures as
        arguments, so an autouse fixture stores the value on ``self`` instead.
        """
        self._metric_loading_script_dir = metric_loading_script_dir

    def setUp(self):
        """Create isolated cache directories and the dynamic-modules package."""
        self.hf_modules_cache = tempfile.mkdtemp()
        self.cache_dir = tempfile.mkdtemp()
        # mkdtemp() leaves deletion to the caller; register the removal so the
        # directories do not pile up. ignore_errors keeps a failed removal
        # (e.g. a file still held open) from turning into a test error.
        self.addCleanup(shutil.rmtree, self.hf_modules_cache, ignore_errors=True)
        self.addCleanup(shutil.rmtree, self.cache_dir, ignore_errors=True)
        self.download_config = DownloadConfig(cache_dir=self.cache_dir)
        self.dynamic_modules_path = evaluate.loading.init_dynamic_modules(
            name="test_datasets_modules_" + os.path.basename(self.hf_modules_cache),
            hf_modules_cache=self.hf_modules_cache,
        )

    def test_HubEvaluationModuleFactory_with_internal_import(self):
        """The module produced for ``evaluate-metric/squad_v2`` can be imported."""
        # "squad_v2" requires additional imports (internal)
        factory = HubEvaluationModuleFactory(
            "evaluate-metric/squad_v2",
            module_type="metric",
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_HubEvaluationModuleFactory_with_external_import(self):
        """The module produced for ``evaluate-metric/bleu`` can be imported."""
        # "bleu" requires additional imports (external from github)
        factory = HubEvaluationModuleFactory(
            "evaluate-metric/bleu",
            module_type="metric",
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_HubEvaluationModuleFactoryWithScript(self):
        """The module produced for ``SAMPLE_METRIC_IDENTIFIER`` can be imported."""
        factory = HubEvaluationModuleFactory(
            SAMPLE_METRIC_IDENTIFIER,
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_LocalMetricModuleFactory(self):
        """The local factory turns the on-disk dummy script into an importable module."""
        path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
        factory = LocalEvaluationModuleFactory(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_CachedMetricModuleFactory(self):
        """The cached factory finds the dummy module by name in every offline mode."""
        path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
        # Load through the local factory first; the result itself is not needed.
        # NOTE(review): presumably this call is what places the module under
        # dynamic_modules_path for the cached factory to find - confirm.
        LocalEvaluationModuleFactory(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        ).get_module()
        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                factory = CachedEvaluationModuleFactory(
                    METRIC_LOADING_SCRIPT_NAME,
                    dynamic_modules_path=self.dynamic_modules_path,
                )
                module_factory_result = factory.get_module()
                assert importlib.import_module(module_factory_result.module_path) is not None

    def test_cache_with_remote_canonical_module(self):
        """``accuracy`` resolves normally and again under every offline mode."""
        metric = "accuracy"
        # First resolution happens outside the offline context.
        evaluation_module_factory(
            metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        # The same call must not raise inside any simulated offline mode.
        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                evaluation_module_factory(
                    metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
                )

    def test_cache_with_remote_community_module(self):
        """``SAMPLE_METRIC_IDENTIFIER`` resolves normally and again under every offline mode."""
        metric = SAMPLE_METRIC_IDENTIFIER
        # First resolution happens outside the offline context.
        evaluation_module_factory(
            metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        # The same call must not raise inside any simulated offline mode.
        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                evaluation_module_factory(
                    metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
                )