gated_rnn.py source code

python

Project: tf-sparql  Author: derdav3
def with_batch_norm_control(self, is_training=True, test_local_stats=True):
    """Wraps this RNNCore with the additional control input to the `BatchNorm`s.

    Example usage:

      lstm = nnd.LSTM(4)
      is_training = tf.placeholder(tf.bool)
      rnn_input = ...
      my_rnn = rnn.rnn(lstm.with_batch_norm_control(is_training), rnn_input)

    Args:
      is_training: Boolean that indicates whether we are in
        training mode or testing mode. When in training mode, the batch norm
        statistics are taken from the given batch, and moving statistics are
        updated. When in testing mode, the moving statistics are not updated,
        and in addition if `test_local_stats` is False then the moving
        statistics are used for the batch statistics. See the `BatchNorm` module
        for more details.
      test_local_stats: Boolean scalar indicating whether to use local
        batch statistics in test mode.

    Returns:
      RNNCell wrapping this class with the extra input(s) added.
    """
    return LSTM.CellWithExtraInput(self,
                                   is_training=is_training,
                                   test_local_stats=test_local_stats)
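The Args section describes three batch-norm regimes: training (batch statistics, moving statistics updated), testing with `test_local_stats=True` (batch statistics, no update), and testing with `test_local_stats=False` (moving statistics). The framework-free sketch below illustrates only those semantics. `ToyBatchNorm`, its `decay`/`eps` parameters and the exponential-moving-average update are assumptions for illustration. It is not Sonnet's `BatchNorm` module.

```python
# Hypothetical, framework-free sketch of the is_training / test_local_stats
# semantics; NOT Sonnet's BatchNorm. decay and eps are assumed parameters.
import math


class ToyBatchNorm:
    """Normalises a 1-D batch of scalars while tracking moving mean/variance."""

    def __init__(self, decay=0.9, eps=1e-5):
        self.decay = decay
        self.eps = eps
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, batch, is_training, test_local_stats=True):
        # Statistics of the current batch.
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)

        if is_training:
            # Training mode: normalise with the batch statistics and update
            # the moving statistics (exponential moving average, assumed).
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        elif not test_local_stats:
            # Test mode without local stats: use the moving statistics.
            mean, var = self.moving_mean, self.moving_var
        # Test mode with local stats falls through: batch statistics are
        # used and the moving statistics are left untouched.
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]
```

For example, `bn([1.0, 2.0, 3.0], is_training=True)` moves `bn.moving_mean` from 0.0 towards the batch mean of 2.0. A later call with `is_training=False` leaves the moving statistics unchanged in both test regimes.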
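The method returns a wrapper because a generic unroll loop (such as `rnn.rnn` in the docstring example) calls a cell with only the input and the state. It has no way to pass `is_training` or `test_local_stats` itself. The sketch below shows that wrapping pattern under stated assumptions. `ToyCore`, `ExtraArgsCore` and `unroll` are hypothetical stand-ins and not the actual `LSTM.CellWithExtraInput` or TensorFlow API.

```python
# Hypothetical sketch of wrapping a core so extra control arguments are
# supplied on every step. ToyCore, ExtraArgsCore and unroll are stand-ins,
# not the Sonnet/TensorFlow API.


class ToyCore:
    """A stand-in recurrent core whose step needs extra control arguments."""

    def __call__(self, x, state, is_training, test_local_stats=True):
        # A real core would pass these flags on to its BatchNorm modules;
        # here the state is a running sum and the flags are echoed back.
        new_state = state + x
        return (new_state, is_training, test_local_stats), new_state


class ExtraArgsCore:
    """Wraps a core so that it can be called with only (input, state)."""

    def __init__(self, core, **extra_kwargs):
        self._core = core
        self._extra_kwargs = extra_kwargs

    def __call__(self, x, state):
        # Forward the stored keyword arguments on every time step.
        return self._core(x, state, **self._extra_kwargs)


def unroll(core, inputs, initial_state):
    """A generic unroll loop that only knows how to pass (input, state)."""
    outputs, state = [], initial_state
    for x in inputs:
        output, state = core(x, state)
        outputs.append(output)
    return outputs, state
```

Usage: `unroll(ExtraArgsCore(ToyCore(), is_training=False, test_local_stats=False), [1, 2, 3], 0)` finishes with state 6, and every step sees the same two flags. `functools.partial` would work for a plain function. A small wrapper class keeps the wrapped core and its arguments together, which is closer to what the method above returns.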