Spaces:
Running
Running
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the DSN model assembly and helper functions in `dsn`."""
import numpy as np
import tensorflow as tf

import dsn
class HelperFunctionsTest(tf.test.TestCase):
  """Tests for the standalone helper functions in `dsn`."""

  def testBasicDomainSeparationStartPoint(self):
    """Checks the DSN loss coefficient around the separation start point.

    With params['domain_separation_startpoint'] == 2, the coefficient returned
    by `dsn.dsn_loss_coefficient` is expected to be ~1e-10 at global steps 0
    and 1, and 1.0 at global step 2.
    """
    params = {'domain_separation_startpoint': 2}
    with self.test_session() as sess:
      step = tf.contrib.slim.get_or_create_global_step()
      increment_step = tf.assign_add(step, 1)
      sess.run(tf.global_variables_initializer())

      def current_weight():
        # Build a fresh coefficient op at every step, as the original test
        # did, rather than assuming a single op re-reads the global step.
        return sess.run(dsn.dsn_loss_coefficient(params))

      # Note: assertAlmostEqual's default 7-decimal tolerance only verifies
      # that the weight is effectively zero, not that it is exactly 1e-10.

      # global_step = 0 < startpoint.
      self.assertEqual(sess.run(step), 0)
      self.assertAlmostEqual(current_weight(), 1e-10)

      # global_step = 1 < startpoint.
      self.assertEqual(sess.run(increment_step), 1)
      self.assertAlmostEqual(current_weight(), 1e-10)

      # global_step = 2 >= startpoint.
      self.assertEqual(sess.run(increment_step), 2)
      self.assertAlmostEqual(current_weight(), 1.0)
class DsnModelAssemblyTest(tf.test.TestCase):
  """Tests that `dsn.create_model` registers the expected number of losses."""

  # Batch size of the synthetic inputs built by `_testBuildDefaultModel`.
  _BATCH_SIZE = 32
  # Number of label classes (depth of the one-hot labels).
  _NUM_CLASSES = 10

  def _testBuildDefaultModel(self):
    """Returns synthetic (images, labels, params) for building a DSN model.

    Returns:
      images: float tensor of shape [32, 28, 28, 1] with values in [0, 1).
      labels: dict whose 'classes' entry is a [32, 10] one-hot tensor.
      params: a fresh hyperparameter dict on every call, so callers may
        mutate it freely.
    """
    # Fixed seed keeps the synthetic inputs identical across runs.
    rng = np.random.RandomState(0)
    images = tf.to_float(rng.rand(self._BATCH_SIZE, 28, 28, 1))
    labels = {
        # randint's upper bound is exclusive, so use _NUM_CLASSES to be able
        # to draw every class index 0.._NUM_CLASSES-1.
        'classes': tf.one_hot(
            tf.to_int32(
                rng.randint(0, self._NUM_CLASSES, self._BATCH_SIZE)),
            self._NUM_CLASSES),
    }
    params = {
        'use_separation': True,
        'layers_to_regularize': 'fc3',
        'weight_decay': 0.0,
        'ps_tasks': 1,
        'domain_separation_startpoint': 1,
        'alpha_weight': 1,
        'beta_weight': 1,
        'gamma_weight': 1,
        'recon_loss_name': 'sum_of_squares',
        'decoder_name': 'small_decoder',
        'encoder_name': 'default_encoder',
    }
    return images, labels, params

  def _buildModelLosses(self, adaptation_loss, **param_overrides):
    """Builds a DSN model on synthetic data and returns its registered losses.

    The same images/labels are fed as both source and target inputs, with an
    all-True domain selection mask.

    Args:
      adaptation_loss: loss name passed to `dsn.create_model`, e.g.
        'dann_loss', 'mmd_loss', 'correlation_loss' or 'none'.
      **param_overrides: entries that replace the default params.

    Returns:
      The list returned by `tf.contrib.losses.get_losses()` after assembly.
    """
    images, labels, params = self._testBuildDefaultModel()
    params.update(param_overrides)
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([self._BATCH_SIZE]), tf.bool),
                       images, labels, adaptation_loss, params, 'dann_mnist')
      return tf.contrib.losses.get_losses()

  def testBuildModelDann(self):
    self.assertEqual(len(self._buildModelLosses('dann_loss')), 6)

  def testBuildModelDannSumOfPairwiseSquares(self):
    # Previously identical to testBuildModelDann; now actually exercises the
    # pairwise reconstruction loss. Assumes dsn accepts
    # 'sum_of_pairwise_squares' as recon_loss_name -- TODO confirm.
    loss_tensors = self._buildModelLosses(
        'dann_loss', recon_loss_name='sum_of_pairwise_squares')
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelDannMultiPSTasks(self):
    loss_tensors = self._buildModelLosses('dann_loss', ps_tasks=10)
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelMmd(self):
    self.assertEqual(len(self._buildModelLosses('mmd_loss')), 6)

  def testBuildModelCorr(self):
    self.assertEqual(len(self._buildModelLosses('correlation_loss')), 6)

  def testBuildModelNoDomainAdaptation(self):
    loss_tensors = self._buildModelLosses('none', use_separation=False)
    self.assertEqual(len(loss_tensors), 1)
    self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 0)

  def testBuildModelNoAdaptationWeightDecay(self):
    loss_tensors = self._buildModelLosses(
        'none', use_separation=False, weight_decay=1e-5)
    self.assertEqual(len(loss_tensors), 1)
    self.assertGreaterEqual(
        len(tf.contrib.losses.get_regularization_losses()), 1)

  def testBuildModelNoSeparation(self):
    loss_tensors = self._buildModelLosses('dann_loss', use_separation=False)
    self.assertEqual(len(loss_tensors), 2)
if __name__ == '__main__':
  # Run every test case defined in this module with TensorFlow's test runner.
  tf.test.main()