residual_lstm_model_MNIST_dataset.py source code

python

Project: Stacked_LSTMS_Highway_Residual_On_TimeSeries_Datasets Author: praveendareddy21
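import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)


def apply_batch_norm(input_tensor, config, i):
    """Normalise input_tensor over all axes, then apply a learnable scale and shift.

    i is the time-step index: variables are created when i == 0 and reused afterwards.
    config is not used in this function.
    """
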
    with tf.variable_scope("batch_norm") as scope:
        if i != 0:
            # Do not create extra variables for each time step
            scope.reuse_variables()

        # Compute one mean and one variance over every axis of the tensor
        axes = list(range(len(input_tensor.get_shape())))

        mean, variance = tf.nn.moments(input_tensor, axes=axes)
        # The small constant keeps the division stable when the variance is near zero
        stdev = tf.sqrt(variance + 0.001)

        # Rescaling
        bn = input_tensor - mean
        bn /= stdev

        # Learnable scale and shift. With stddev=0.0, tf.random_normal yields
        # constant initial values: 0.5 for the scale and 0.0 for the shift.
        bn *= tf.get_variable("a_noreg", initializer=tf.random_normal([1], mean=0.5, stddev=0.0))
        bn += tf.get_variable("b_noreg", initializer=tf.random_normal([1], mean=0.0, stddev=0.0))
        # Alternative using tf.Variable, which would create new variables on each call:
        # bn *= tf.Variable(0.5, name=(scope.name + "/a_noreg"))
        # bn += tf.Variable(0.0, name=(scope.name + "/b_noreg"))

    return bn
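
To make the arithmetic concrete, here is a minimal pure-Python sketch of the same computation on a flat list of numbers. It is not part of the project: the function name `normalise_all_axes` and its parameters `scale`, `shift` and `eps` are hypothetical, chosen to mirror the initial values of `a_noreg` (0.5) and `b_noreg` (0.0) and the 0.001 constant in the function above. The scale and shift are fixed here rather than learned.

```python
import math


def normalise_all_axes(values, scale=0.5, shift=0.0, eps=0.001):
    """Hypothetical stand-in for apply_batch_norm on a flat list of floats."""
    n = len(values)
    mean = sum(values) / n
    # Population (biased) variance, matching what tf.nn.moments computes
    variance = sum((v - mean) ** 2 for v in values) / n
    stdev = math.sqrt(variance + eps)
    # Centre, divide by the standard deviation, then apply scale and shift
    return [scale * (v - mean) / stdev + shift for v in values]


out = normalise_all_axes([1.0, 2.0, 3.0, 4.0])
```

Because the statistics are taken over all values at once, the output is centred on `shift` and its spread is set by `scale`, independent of the magnitude of the input.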