gated_rnn.py source code

python

Project: tf-sparql  Author: derdav3
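import tensorflow as tf
from sonnet.python.modules import batch_norm


# Method of the IndexedStatsBatchNorm module; relies on self._max_unique_stats
# having been set in the module's constructor.
def _build(self, inputs, index, is_training, test_local_stats):
  """Add the IndexedStatsBatchNorm module to the graph.

  Args:
    inputs: Tensor to apply batch norm to.
    index: Scalar TensorFlow int32 value to select the batch norm index.
    is_training: Boolean to indicate to `batch_norm.BatchNorm` if we are
      currently training.
    test_local_stats: Boolean to indicate to `batch_norm.BatchNorm` if batch
      normalization should use local batch statistics at test time.

  Returns:
    Output of batch norm operation.
  """
  def create_batch_norm():
    # Each call builds a fresh BatchNorm module, i.e. its own set of moving
    # statistics; the learned offset and scale are disabled.
    return batch_norm.BatchNorm(offset=False, scale=False)(
        inputs, is_training, test_local_stats)

  if self._max_unique_stats > 1:
    # Branch i is taken when index == i; any other index falls through to the
    # default branch, giving _max_unique_stats distinct sets of statistics.
    pred_fn_pairs = [(tf.equal(i, index), create_batch_norm)
                     for i in range(self._max_unique_stats - 1)]
    out = tf.case(pred_fn_pairs, create_batch_norm)
    out.set_shape(inputs.get_shape())  # needed for tf.case shape inference
    return out
  else:
    return create_batch_norm()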
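To make the dispatch logic easier to follow without a TensorFlow graph, here is a minimal framework-free sketch of the same idea, assuming 1-D lists of floats as inputs. It keeps one independent (moving mean, moving variance) pair per index and picks the pair with a first-true-predicate search that stands in for `tf.case`; any index at or above `max_unique_stats - 1` falls through to the last slot, as the default branch does above. The names `case`, `IndexedStatsBatchNormSketch`, `decay` and `eps` are hypothetical, and the plain exponential-moving-average update is a simplification, not Sonnet's actual `BatchNorm` implementation.

```python
import math


def case(pred_fn_pairs, default):
    """Simplified stand-in for tf.case: call the fn of the first True predicate."""
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()
    return default()


class IndexedStatsBatchNormSketch:
    """Hypothetical sketch: one set of moving statistics per index, no offset/scale."""

    def __init__(self, max_unique_stats, decay=0.9, eps=1e-3):
        self.max_unique_stats = max_unique_stats
        self.decay = decay
        self.eps = eps
        # One (moving_mean, moving_var) pair per statistics slot.
        self.moving_mean = [0.0] * max_unique_stats
        self.moving_var = [1.0] * max_unique_stats

    def _normalize(self, slot, xs, is_training, test_local_stats):
        if is_training or test_local_stats:
            # Use statistics of the current batch.
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            if is_training:
                # Simplified moving-average update, only for the selected slot.
                d = self.decay
                self.moving_mean[slot] = d * self.moving_mean[slot] + (1 - d) * mean
                self.moving_var[slot] = d * self.moving_var[slot] + (1 - d) * var
        else:
            # Inference with stored statistics of the selected slot.
            mean, var = self.moving_mean[slot], self.moving_var[slot]
        return [(x - mean) / math.sqrt(var + self.eps) for x in xs]

    def __call__(self, xs, index, is_training, test_local_stats):
        def branch(slot):
            return lambda: self._normalize(slot, xs, is_training, test_local_stats)

        n = self.max_unique_stats
        if n > 1:
            pairs = [(i == index, branch(i)) for i in range(n - 1)]
            # Any index >= n - 1 uses the last slot, like the tf.case default.
            return case(pairs, branch(n - 1))
        return branch(0)()
```

For example, training with index 0 and index 5 on a 3-slot instance updates slots 0 and 2 respectively while slot 1 stays untouched, which is the property the indexed module is built for: each index keeps its own normalization statistics.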