def _build(self, inputs, index, is_training, test_local_stats):
    """Add the IndexedStatsBatchNorm module to the graph.

    If `self._max_unique_stats` is greater than 1, the method builds
    `self._max_unique_stats - 1` explicit batch norm branches and one default
    branch. Explicit branch `i` is selected when `index == i`, for `i` in
    `[0, self._max_unique_stats - 2]`. `tf.case` chooses among the branches at
    run time. Any `index` that matches no explicit branch falls through to the
    default branch, for example `index >= self._max_unique_stats - 1` or a
    negative index. Otherwise a single batch norm is applied directly.

    Args:
      inputs: Tensor to apply batch norm to.
      index: Scalar TensorFlow int32 value to select the batch norm index.
      is_training: Boolean to indicate to `snt.BatchNorm` if we are
        currently training.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.

    Returns:
      Output of batch norm operation, with the static shape of `inputs`.
    """

    def create_batch_norm():
        # Constructs a new BatchNorm module (offset and scale disabled) on
        # every call. Each tf.case branch therefore gets its own module
        # instance and, presumably, its own statistics.
        return batch_norm.BatchNorm(offset=False, scale=False)(
            inputs, is_training, test_local_stats)

    if self._max_unique_stats <= 1:
        # A single set of statistics needs no run-time dispatch.
        return create_batch_norm()

    # Use built-in `range` instead of the Python-2-only `xrange`. It behaves
    # the same here and does not depend on a `six.moves` import.
    pred_fn_pairs = [(tf.equal(i, index), create_batch_norm)
                     for i in range(self._max_unique_stats - 1)]
    out = tf.case(pred_fn_pairs, create_batch_norm)
    out.set_shape(inputs.get_shape())  # needed for tf.case shape inference
    return out
评论列表
文章目录