|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for ops_text.""" |
|
import copy

from absl.testing import parameterized
import big_vision.pp.ops_text as pp
from big_vision.pp.registry import Registry
import numpy as np
import tensorflow as tf
|
|
class PyToTfWrapper:
  """Allows using `to_{int,str}_tf_op()` via `to_{int,str}()`."""

  def __init__(self, tok):
    self.tok = tok
    self.bos_token = tok.bos_token
    self.eos_token = tok.eos_token
    self.vocab_size = tok.vocab_size

  def to_int(self, text, *, bos=False, eos=False):
    ret = self.tok.to_int_tf_op(text, bos=bos, eos=eos)
    if isinstance(ret, tf.RaggedTensor):
      return [t.numpy().tolist() for t in ret]
    return ret.numpy().tolist()

  def to_str(self, tokens, stop_at_eos=True):
    ret = self.tok.to_str_tf_op(
        tf.ragged.constant(tokens),
        stop_at_eos=stop_at_eos,
    )
    if ret.ndim == 0:
      return ret.numpy().decode()
    return [t.numpy().decode() for t in ret]
|
|
class PpOpsTest(tf.test.TestCase, parameterized.TestCase):

  def tfrun(self, ppfn, data):
    # Run the preprocessing function once eagerly, directly on a copy of the
    # input data.
    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}

    # And once more inside a tf.data pipeline, where it runs in graph mode.
    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
    for npdata in tfdata.map(ppfn).as_numpy_iterator():
      yield npdata
|
  def testtok(self):
    # Path of the SentencePiece test model (vocab size 1000) used below.
    return "test_model.model"
|
  def test_get_pp_clip_i1k_label_names(self):
    op = pp.get_pp_clip_i1k_label_names()
    labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist()
    self.assertAllEqual(labels, ["tench", "goldfish"])
|
  @parameterized.parameters((b"Hello world ScAlAr!", b"hello world scalar!"),
                            (["Decoded Array!"], ["decoded array!"]),
                            ([b"aA", "bB"], [b"aa", "bb"]))
  def test_get_lower(self, inputs, expected_output):
    op = pp.get_lower()
    out = op({"text": tf.constant(inputs)})
    self.assertAllEqual(out["text"].numpy(), np.array(expected_output))
|
  @parameterized.named_parameters(
      ("py", False),
      ("tf", True),
  )
  def test_sentencepiece_tokenizer(self, wrap_tok):
    tok = pp.SentencepieceTokenizer(self.testtok())
    if wrap_tok:
      tok = PyToTfWrapper(tok)
    self.assertEqual(tok.vocab_size, 1000)
    bos, eos = tok.bos_token, tok.eos_token
    self.assertEqual(bos, 1)
    self.assertEqual(eos, 2)

    # Tokenization, with and without BOS/EOS, for scalar and batched inputs.
    self.assertEqual(tok.to_int("blah"), [80, 180, 60])
    self.assertEqual(tok.to_int("blah", bos=True), [bos, 80, 180, 60])
    self.assertEqual(tok.to_int("blah", eos=True), [80, 180, 60, eos])
    self.assertEqual(
        tok.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos]
    )
    self.assertEqual(
        tok.to_int(["blah", "blah blah"]),
        [[80, 180, 60], [80, 180, 60, 80, 180, 60]],
    )

    # Detokenization; BOS/EOS tokens are dropped from the output.
    self.assertEqual(tok.to_str([80, 180, 60]), "blah")
    self.assertEqual(tok.to_str([1, 80, 180, 60]), "blah")
    self.assertEqual(tok.to_str([80, 180, 60, 2]), "blah")
    self.assertEqual(
        tok.to_str([[80, 180, 60], [80, 180, 60, 80, 180, 60]]),
        ["blah", "blah blah"],
    )
|
  def test_sentencepiece_tokenizer_tf_op_ndarray_input(self):
    tok = pp.SentencepieceTokenizer(self.testtok())
    bos, eos = tok.bos_token, tok.eos_token
    arr = np.array([[bos, 80, 180, 60, eos]] * 2, dtype=np.int32)
    self.assertEqual(tok.to_str_tf_op(arr).numpy().tolist(), [b"blah"] * 2)
|
  def test_sentencepiece_tokenizer_tokensets(self):
    # The "loc" tokenset adds 1024 location tokens <loc0000>..<loc1023> after
    # the base vocabulary.
    tok = pp.SentencepieceTokenizer(self.testtok(), tokensets=["loc"])
    self.assertEqual(tok.vocab_size, 2024)
    self.assertEqual(
        tok.to_int("blah<loc0000><loc1023>"), [80, 180, 60, 1000, 2023]
    )
|
  def test_sentencepiece_stop_at_eos(self):
    tok = pp.SentencepieceTokenizer(self.testtok())
    self.assertEqual(tok.to_str([80, 180, 60], stop_at_eos=False), "blah")
    eos = tok.eos_token
    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=False), "blah")
    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=True), "b")
    self.assertEqual(
        tok.to_str([[80, eos, 180, 60], [80, 180, eos, 60]], stop_at_eos=True),
        ["b", "bla"]
    )
|
  def test_sentencepiece_extra_tokens(self):
    # Without extra tokensets, BOS/EOS are dropped when decoding.
    tok = pp.SentencepieceTokenizer(self.testtok())
    self.assertEqual(tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "blah")
    # With the "sp_extra_tokens" tokenset (registered below), only "<pad>" is a
    # new piece (vocab 1000 -> 1001), and BOS/EOS decode to their literal text.
    tok = pp.SentencepieceTokenizer(
        self.testtok(), tokensets=["sp_extra_tokens"]
    )
    self.assertEqual(tok.vocab_size, 1001)
    self.assertEqual(
        tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "<s> blah</s>"
    )
|
|
@Registry.register("tokensets.sp_extra_tokens")
def _get_sp_extra_tokens():
  # Extra tokenset used by `test_sentencepiece_extra_tokens`, registered so it
  # can be requested via `tokensets=["sp_extra_tokens"]`.
  return ["<s>", "</s>", "<pad>"]
|
|
if __name__ == "__main__":
  tf.test.main()