# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# raw
# history blame contribute delete
# 5.9 kB
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN components."""
import numpy as np
import tensorflow as tf
#from models.domain_adaptation.domain_separation
import models
class SharedEncodersTest(tf.test.TestCase):
  """Smoke tests that build each shared-encoder model and inspect its logits."""

  def _testSharedEncoder(self,
                         input_shape=(5, 28, 28, 1),
                         model=models.dann_mnist,
                         is_training=True):
    """Builds `model` on random images and returns the evaluated logits.

    Args:
      input_shape: Sequence of ints unpacked into `np.random.rand` to create
        the input batch. A tuple default is used so that no mutable object is
        shared between calls.
      model: Callable invoked as `model(images)`; it must return a pair whose
        first element is the logits tensor.
      is_training: Unused. Kept only so that existing callers keep working.

    Returns:
      The logits evaluated by the test session.
    """
    del is_training  # Unused.
    images = tf.to_float(np.random.rand(*input_shape))

    with self.test_session() as sess:
      logits, _ = model(images)
      sess.run(tf.global_variables_initializer())
      logits_np = sess.run(logits)
    return logits_np

  def _assertLogits(self, logits, expected_shape):
    """Asserts that `logits` has `expected_shape` and is not all zeros."""
    self.assertEqual(logits.shape, expected_shape)
    self.assertTrue(np.any(logits))

  def testBuildGRLMnistModel(self):
    logits = self._testSharedEncoder(model=models.dann_mnist)
    self._assertLogits(logits, (5, 10))

  def testBuildGRLSvhnModel(self):
    # NOTE(review): this relies on the helper's default 28x28x1 input shape;
    # confirm that is the intended input for the SVHN model.
    logits = self._testSharedEncoder(model=models.dann_svhn)
    self._assertLogits(logits, (5, 10))

  def testBuildGRLGtsrbModel(self):
    logits = self._testSharedEncoder([5, 40, 40, 3], models.dann_gtsrb)
    self._assertLogits(logits, (5, 43))

  def testBuildPoseModel(self):
    logits = self._testSharedEncoder([5, 64, 64, 4],
                                     models.dsn_cropped_linemod)
    self._assertLogits(logits, (5, 11))

  def testBuildPoseModelWithBatchNorm(self):
    # The lambda defers building the batch-norm params until the model is
    # constructed inside the test session, as the hand-written version did.
    logits = self._testSharedEncoder(
        [10, 64, 64, 4],
        lambda images: models.dsn_cropped_linemod(
            images, batch_norm_params=models.default_batch_norm_params(True)))
    self._assertLogits(logits, (10, 11))
class EncoderTest(tf.test.TestCase):
  """Checks the 'fc3' end point produced by `models.default_encoder`."""

  def _testEncoder(self, batch_norm_params=None, channels=1):
    """Encodes a random 28x28 image batch and validates the resulting code.

    Args:
      batch_norm_params: Forwarded unchanged to `models.default_encoder`.
      channels: Size of the last dimension of the random input batch.
    """
    batch_size = 10
    code_size = 128
    random_pixels = np.random.rand(batch_size, 28, 28, channels)
    images = tf.to_float(random_pixels)

    with self.test_session() as sess:
      end_points = models.default_encoder(
          images, code_size, batch_norm_params=batch_norm_params)
      sess.run(tf.global_variables_initializer())
      code = sess.run(end_points['fc3'])

    # The code must have one row per image, be non-trivial and contain no
    # NaN/inf entries.
    self.assertEqual(code.shape, (batch_size, code_size))
    self.assertTrue(code.any())
    self.assertTrue(np.isfinite(code).all())

  def testEncoder(self):
    self._testEncoder()

  def testEncoderMultiChannel(self):
    self._testEncoder(channels=4)

  def testEncoderIsTrainingBatchNorm(self):
    bn_params = models.default_batch_norm_params(True)
    self._testEncoder(bn_params)

  def testEncoderBatchNorm(self):
    bn_params = models.default_batch_norm_params(False)
    self._testEncoder(bn_params)
class DecoderTest(tf.test.TestCase):
  """Smoke tests for the small, large and GTSRB decoders in `models`."""

  def _testDecoder(self,
                   height=64,
                   width=64,
                   channels=4,
                   batch_norm_params=None,
                   decoder=models.small_decoder):
    """Decodes a random batch of codes and validates the decoder output.

    Args:
      height: Requested output height, forwarded to `decoder`.
      width: Requested output width, forwarded to `decoder`.
      channels: Requested number of output channels, forwarded to `decoder`.
      batch_norm_params: Forwarded unchanged to `decoder`.
      decoder: Callable invoked as `decoder(codes, height=..., width=...,
        channels=..., batch_norm_params=...)` that returns the output tensor.
    """
    codes = tf.to_float(np.random.rand(32, 100))

    with self.test_session() as sess:
      output = decoder(
          codes,
          height=height,
          width=width,
          channels=channels,
          batch_norm_params=batch_norm_params)
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)

    # One image per code, with the requested spatial size and channel count,
    # not all zeros and free of NaN/inf values.
    self.assertEqual(output_np.shape, (32, height, width, channels))
    self.assertTrue(np.any(output_np))
    self.assertTrue(np.all(np.isfinite(output_np)))

  def testSmallDecoder(self):
    self._testDecoder(28, 28, 4, None, models.small_decoder)

  def testSmallDecoderThreeChannels(self):
    self._testDecoder(28, 28, 3)

  def testSmallDecoderBatchNorm(self):
    self._testDecoder(28, 28, 4, models.default_batch_norm_params(False))

  def testSmallDecoderIsTrainingBatchNorm(self):
    self._testDecoder(28, 28, 4, models.default_batch_norm_params(True))

  def testLargeDecoder(self):
    self._testDecoder(32, 32, 4, None, models.large_decoder)

  def testLargeDecoderThreeChannels(self):
    self._testDecoder(32, 32, 3, None, models.large_decoder)

  def testLargeDecoderBatchNorm(self):
    self._testDecoder(32, 32, 4,
                      models.default_batch_norm_params(False),
                      models.large_decoder)

  def testLargeDecoderIsTrainingBatchNorm(self):
    self._testDecoder(32, 32, 4,
                      models.default_batch_norm_params(True),
                      models.large_decoder)

  def testGtsrbDecoder(self):
    # Exercise the GTSRB decoder itself, consistent with the two batch-norm
    # GTSRB tests below, so that the decoder named by this test is covered.
    self._testDecoder(40, 40, 3, None, models.gtsrb_decoder)

  def testGtsrbDecoderBatchNorm(self):
    self._testDecoder(40, 40, 4,
                      models.default_batch_norm_params(False),
                      models.gtsrb_decoder)

  def testGtsrbDecoderIsTrainingBatchNorm(self):
    self._testDecoder(40, 40, 4,
                      models.default_batch_norm_params(True),
                      models.gtsrb_decoder)
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()