|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for registry.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
from unittest import mock |
|
|
|
from absl.testing import absltest |
|
from big_vision.pp import registry |
|
|
|
|
|
class RegistryTest(absltest.TestCase):
  """Exercises `registry.parse_name` and `registry.Registry` register/lookup."""

  def setUp(self):
    """Swaps `Registry.global_registry` for a mock backed by a fresh dict."""
    super().setUp()

    # Registered before the patch is started so every patch begun below is
    # stopped again when the test finishes.
    self.addCleanup(mock.patch.stopall)

    # The mock hands back this (initially empty) dict on every call, so each
    # test method begins with an empty registry.
    self.global_registry = {}
    patcher = mock.patch.object(
        registry.Registry,
        "global_registry",
        return_value=self.global_registry,
    )
    self.mocked_method = patcher.start()

  def test_parse_name(self):
    """Checks the (name, args, kwargs) split for valid and invalid specs."""
    # Each row: (spec, expected name, expected args, expected kwargs).
    well_formed = (
        ("f", "f", (), {}),
        ("f()", "f", (), {}),
        ("func(a=0,b=1,c='s')", "func", (), {"a": 0, "b": 1, "c": "s"}),
        ("func(1,'foo',3)", "func", (1, "foo", 3), {}),
        (
            "func(1,'2',a=3,foo='bar')",
            "func",
            (1, "2"),
            {"a": 3, "foo": "bar"},
        ),
    )
    for spec, want_name, want_args, want_kwargs in well_formed:
      got_name, got_args, got_kwargs = registry.parse_name(spec)
      self.assertEqual(got_name, want_name)
      self.assertEqual(got_args, want_args)
      self.assertEqual(got_kwargs, want_kwargs)

    # Dotted name with a parenthesised literal; only the name and the kwargs
    # are checked for this spec.
    parsed = registry.parse_name("foo.bar.func(a=0,b=(1),c='s')")
    dotted_name, _, dotted_kwargs = parsed
    self.assertEqual(dotted_name, "foo.bar.func")
    self.assertEqual(dotted_kwargs, {"a": 0, "b": 1, "c": "s"})

    # Each row: (spec, exception type expected from parse_name).
    malformed = (
        ("func(0", SyntaxError),
        ("func(a=0,,b=0)", SyntaxError),
        ("func(a=0,b==1,c='s')", SyntaxError),
        ("func(a=0,b=undefined_name,c='s')", ValueError),
    )
    for spec, error_type in malformed:
      with self.assertRaises(error_type):
        registry.parse_name(spec)

  def test_register(self):
    """Registering one function leaves exactly one entry in the registry."""

    @registry.Registry.register("func1")
    def func1():
      pass

    entries = registry.Registry.global_registry()
    self.assertLen(entries, 1)

  def test_lookup_function(self):
    """Checks that lookup specs pre-bind arguments of the registered function."""

    @registry.Registry.register("func1")
    def func1(arg1, arg2, arg3):
      return arg1, arg2, arg3

    self.assertTrue(callable(registry.Registry.lookup("func1")))

    # Each row: (spec, call-time args, call-time kwargs, expected result).
    bound_calls = (
        ("func1", (1, 2, 3), {}, (1, 2, 3)),
        ("func1(arg3=9)", (1, 2), {}, (1, 2, 9)),
        ("func1(arg2=9,arg1=99)", (), {"arg3": 3}, (99, 9, 3)),
        ("func1(arg2=9,arg1=99)", (), {"arg1": 1, "arg3": 3}, (1, 9, 3)),
        ("func1(1)", (1, 2), {}, (1, 1, 2)),
        ("func1(1)", (), {"arg3": 3, "arg2": 2}, (1, 2, 3)),
        ("func1(1, 2)", (3,), {}, (1, 2, 3)),
        ("func1(1, 2)", (), {"arg3": 3}, (1, 2, 3)),
        ("func1(1, arg2=2)", (), {"arg3": 3}, (1, 2, 3)),
        ("func1(1, arg3=2)", (), {"arg2": 3}, (1, 3, 2)),
        ("func1(1, arg3=2)", (3,), {}, (1, 3, 2)),
    )
    for spec, call_args, call_kwargs, expected in bound_calls:
      bound = registry.Registry.lookup(spec)
      self.assertEqual(bound(*call_args, **call_kwargs), expected)

    # Each row: (spec, call-time args, call-time kwargs, expected exception).
    # Lookup and call both sit inside the assertRaises block, so the error
    # may come from either step.
    failing_calls = (
        ("func1(1, arg2=2)", (3,), {}, TypeError),
        ("func1(1, arg3=3)", (), {"arg3": 3}, TypeError),
        ("func1(1, arg3=3)", (), {"arg1": 3}, TypeError),
        ("func1(arg1=1, 3)", (), {"arg2": 3}, SyntaxError),
    )
    for spec, call_args, call_kwargs, error_type in failing_calls:
      with self.assertRaises(error_type):
        registry.Registry.lookup(spec)(*call_args, **call_kwargs)
|
|
|
|
|
if __name__ == "__main__":
  # When run as a script, delegate to absltest's entry point.
  absltest.main()
|
|