def with_batch_norm_control(self, is_training=True, test_local_stats=True):
    """Return this core wrapped with the extra control inputs for its `BatchNorm`s.

    The wrapper is built by `LSTM.CellWithExtraInput`, which receives this
    core plus the two batch-norm control values as keyword arguments.
    Presumably the wrapper forwards these values to the core at each step;
    see `LSTM.CellWithExtraInput` to confirm.

    Example usage:

        lstm = nnd.LSTM(4)
        is_training = tf.placeholder(tf.bool)
        rnn_input = ...
        my_rnn = rnn.rnn(lstm.with_batch_norm_control(is_training), rnn_input)

    Args:
      is_training: Boolean selecting training or testing mode (defaults to
        True).
        - Training mode: batch norm statistics are taken from the current
          batch, and the moving statistics are updated.
        - Testing mode: the moving statistics are not updated. If
          `test_local_stats` is also False, the moving statistics are used
          in place of the batch statistics.
        See the `BatchNorm` module for details.
      test_local_stats: Boolean scalar indicating whether local batch
        statistics are used in testing mode (defaults to True).

    Returns:
      The object constructed by `LSTM.CellWithExtraInput`, wrapping this
      instance together with the extra control input(s).
    """
    # Gather the batch-norm controls once, so the wrapper call below names
    # exactly which extra inputs are attached to this core.
    control_kwargs = {
        "is_training": is_training,
        "test_local_stats": test_local_stats,
    }
    wrapper_cls = LSTM.CellWithExtraInput
    return wrapper_cls(self, **control_kwargs)
评论列表
文章目录