|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for ops_image.""" |
|
|
|
import copy |
|
import io |
|
|
|
import big_vision.pp.ops_image as pp |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
def get_image_data(shape=(640, 480, 3)):
  """Returns a feature dict holding a random uint8 image.

  Args:
    shape: shape of the generated image tensor. Defaults to (640, 480, 3),
      the size the shape assertions in the tests of this module rely on.

  Returns:
    A dict `{"image": tensor}` where the tensor has dtype uint8 and the given
    shape, with values drawn uniformly from the full uint8 range [0, 255].
  """
  # For integer dtypes `tf.random.uniform` excludes `maxval`, so the bound must
  # be 256 for 255 to be reachable (the old bound of 255 never produced it).
  img = tf.random.uniform(shape, minval=0, maxval=256, dtype=tf.int32)
  return {"image": tf.cast(img, tf.uint8)}
|
|
|
|
|
class PreprocessOpsTest(tf.test.TestCase):
  """Checks the outputs of the image preprocessing ops in `pp`."""

  def tfrun(self, ppfn, data):
    """Runs `ppfn` on `data` both eagerly and inside a tf.data pipeline.

    Each result is converted to a dict of numpy arrays, so a test can apply
    the same assertions to both execution modes.

    Args:
      ppfn: preprocessing function mapping a feature dict to a feature dict.
      data: input feature dict. It is deep-copied before each run, so `ppfn`
        never receives the caller's dict object.

    Yields:
      First the eager result, then the result(s) produced by
      `tf.data.Dataset.map`.
    """
    # Standalone eager call, as when the op is used outside a pipeline.
    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}

    # Run again through `Dataset.map`, which traces `ppfn` into a graph.
    # Eager and graph execution can behave differently, so check both.
    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
    yield from tfdata.map(ppfn).as_numpy_iterator()

  def test_resize(self):
    for data in self.tfrun(pp.get_resize([120, 80]), get_image_data()):
      self.assertEqual(data["image"].shape, (120, 80, 3))

  def test_resize_small(self):
    # (640, 480) input: the 480 side goes to 240, and the other side scales
    # proportionally to 320.
    for data in self.tfrun(pp.get_resize_small(240), get_image_data()):
      self.assertEqual(data["image"].shape, (320, 240, 3))

  def test_resize_long(self):
    # (640, 480) input: the 640 side goes to 320, and the other side scales
    # proportionally to 240.
    for data in self.tfrun(pp.get_resize_long(320), get_image_data()):
      self.assertEqual(data["image"].shape, (320, 240, 3))

  def test_inception_crop(self):
    # The crop size is random, so only the channel count is checked.
    for data in self.tfrun(pp.get_inception_crop(), get_image_data()):
      self.assertEqual(data["image"].shape[-1], 3)

  def test_decode_jpeg_and_inception_crop(self):
    # Encode a random image to JPEG bytes so the op gets a real encoded input.
    buf = io.BytesIO()
    plt.imsave(buf, get_image_data()["image"].numpy(), format="jpg")
    encoded = {"image": tf.constant(buf.getvalue(), dtype=tf.string)}
    for data in self.tfrun(pp.get_decode_jpeg_and_inception_crop(), encoded):
      self.assertEqual(data["image"].shape[-1], 3)

  def test_random_crop(self):
    for data in self.tfrun(pp.get_random_crop([120, 80]), get_image_data()):
      self.assertEqual(data["image"].shape, (120, 80, 3))

  def test_central_crop(self):
    for data in self.tfrun(pp.get_central_crop([20, 80]), get_image_data()):
      self.assertEqual(data["image"].shape, (20, 80, 3))

  def test_random_flip_lr(self):
    data_orig = get_image_data()
    orig = data_orig["image"].numpy()
    for data in self.tfrun(pp.get_random_flip_lr(), data_orig):
      # The output must equal either the input or the input reversed along
      # axis 1. `array_equal` also checks shapes, so a shape mismatch fails
      # cleanly instead of broadcasting or raising.
      self.assertTrue(
          np.array_equal(orig, data["image"]) or
          np.array_equal(orig, data["image"][:, ::-1]))
|
|
|
if __name__ == "__main__":
  # Only when executed as a script: hand off to TensorFlow's test runner,
  # which runs the test cases defined in this module.
  tf.test.main()
|
|