# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create a DSN model and add the different losses to it. | |
Specifically, in this file we define the: | |
- Shared Encoding Similarity Loss Module, with: | |
- The MMD Similarity method | |
- The Correlation Similarity method | |
- The Gradient Reversal (Domain-Adversarial) method | |
- Difference Loss Module | |
- Reconstruction Loss Module | |
- Task Loss Module | |
""" | |
from functools import partial

import tensorflow as tf

import losses
import models
import utils

slim = tf.contrib.slim

################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
  """The global_step-dependent weight that specifies when to kick in DSN losses.

  Args:
    params: A dictionary of parameters. Expecting 'domain_separation_startpoint'.

  Returns:
    A weight that effectively enables or disables the DSN-related losses,
    i.e. similarity, difference, and reconstruction losses.
  """
  return tf.where(
      tf.less(slim.get_or_create_global_step(),
              params['domain_separation_startpoint']), 1e-10, 1.0)
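

# Illustrative sketch (not part of the model): the coefficient above is
# effectively 0 (1e-10) before `domain_separation_startpoint` training steps
# and 1.0 afterwards, so the similarity, difference, and reconstruction
# losses are switched on only after the task network has warmed up. A
# hypothetical stand-alone version taking an explicit step tensor:
def _example_warmup_coefficient(global_step, startpoint):
  """Returns 1e-10 before `startpoint` steps and 1.0 afterwards."""
  return tf.where(tf.less(global_step, startpoint), 1e-10, 1.0)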


################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels].
    source_labels: a dictionary of name, tensor pairs. 'classes' is a one-hot
      tensor over the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_labels: a dictionary of name, tensor pairs.
    similarity_loss: the name of the method used to encourage the codes from
      the shared encoder to be similar.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'.
    basic_tower_name: the name of the tower to use for the shared encoder.

  Raises:
    ValueError: if the architecture is not one of the available architectures.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Log the target-domain accuracy on the training set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)
  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target-domain samples
  # when we compute the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params)
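

# Hedged usage sketch: the real flag plumbing lives in the training script for
# this model; the dictionary below only illustrates the keys create_model
# expects, with made-up example values. The encoder/decoder names are
# placeholders and must match functions defined in models.py.
def _example_create_model_params():
  return {
      'weight_decay': 1e-5,
      'layers_to_regularize': 'fc3',
      'use_separation': True,
      'domain_separation_startpoint': 1000,
      'alpha_weight': 0.01,  # reconstruction loss weight
      'beta_weight': 0.05,  # difference (shared/private) loss weight
      'gamma_weight': 0.25,  # similarity loss weight
      'pose_weight': 0.25,  # pose loss weight, used by add_task_loss
      'recon_loss_name': 'sum_of_squares',
      'decoder_name': 'small_decoder',
      'encoder_name': 'default_encoder',
  }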


def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss encouraging the shared encoding from each domain to be similar.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include 'dann_loss', 'mmd_loss' or 'correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.

  Raises:
    ValueError: if `method_name` is not recognized.
  """
  weight = dsn_loss_coefficient(params) * params['gamma_weight']
  method = getattr(losses, method_name)
  method(source_samples, target_samples, weight, scope)
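

# Hedged example: given two batches of shared codes of shape
# [num_samples, num_features], the call below would add the MMD penalty
# (scaled by gamma_weight and the warm-up coefficient) to the graph. The loss
# functions themselves ('dann_loss', 'mmd_loss', 'correlation_loss') are
# defined in losses.py in this directory.
def _example_similarity_loss(source_codes, target_codes, params):
  add_similarity_loss('mmd_loss', source_codes, target_codes, params,
                      scope='example_similarity')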


def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss.

  Args:
    recon_loss_name: The name of the reconstruction loss.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  if recon_loss_name == 'sum_of_pairwise_squares':
    loss_fn = tf.contrib.losses.mean_pairwise_squared_error
  elif recon_loss_name == 'sum_of_squares':
    loss_fn = tf.contrib.losses.mean_squared_error
  else:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  loss = loss_fn(recons, images, weight)

  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
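

# For reference (hedged sketch, not used by the model): 'sum_of_squares' is a
# plain mean squared error, while 'sum_of_pairwise_squares' is an
# offset-invariant flavor that ignores a constant shift between the image and
# its reconstruction. Roughly:
def _example_offset_invariant_mse(images, recons):
  """MSE of the error after removing its per-image mean (rough sketch)."""
  error = recons - images
  mean_error = tf.reduce_mean(error, axis=[1, 2, 3], keep_dims=True)
  return tf.reduce_mean(tf.square(error - mean_error))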


def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels].
    source_shared: a tensor with first dimension batch_size.
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_shared: a tensor with first dimension batch_size.
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'.
  """

  def normalize_images(images):
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)

  # Private encoders, one per domain.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  # Shared decoder, reused for all reconstructions.
  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]

  # Source auto-encoding.
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  # Target auto-encoding.
  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries. Note: `map` is wrapped in `list` so the code also works
  # under Python 3, where `map` returns an iterator.
  source_reconstructions = tf.concat(
      axis=2,
      values=list(map(normalize_images, [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ])))
  target_reconstructions = tf.concat(
      axis=2,
      values=list(map(normalize_images, [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ])))
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)
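

# For reference, losses.difference_loss (used above) pushes each domain's
# shared and private codes towards orthogonality. A minimal sketch of that
# idea, ignoring the normalization that the real implementation applies, is
# the squared Frobenius norm of the cross-correlation between the two codes:
def _example_difference_penalty(shared_code, private_code):
  """||shared_code^T private_code||_F^2 over a batch (unnormalized sketch)."""
  correlation = tf.matmul(shared_code, private_code, transpose_a=True)
  return tf.reduce_sum(tf.square(correlation))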


def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels].
    source_labels: a dictionary of labels from the source domain. Expecting
      'classes' (one-hot class labels) and, for pose estimation models,
      'quaternions'.
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if the basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well.
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model that supports pose estimation, '
                         'e.g. pose_mini.')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = loss
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints
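

# Hedged end-to-end sketch (illustrative only; the actual entry point is the
# training script for this model). It wires placeholder MNIST-sized batches
# through create_model, reusing the illustrative parameter dictionary defined
# above. The tower name is a placeholder and must match a function defined in
# models.py.
def _example_build_dsn_graph():
  source_images = tf.placeholder(tf.float32, [32, 28, 28, 1])
  target_images = tf.placeholder(tf.float32, [32, 28, 28, 1])
  source_labels = {'classes': tf.placeholder(tf.float32, [32, 10])}
  target_labels = {'classes': tf.placeholder(tf.float32, [32, 10])}
  # True for labeled examples that actually come from the source domain.
  domain_selection_mask = tf.placeholder(tf.bool, [32])
  create_model(source_images, source_labels, domain_selection_mask,
               target_images, target_labels, 'mmd_loss',
               _example_create_model_params(), 'dann_mnist')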