normalization.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tensorbayes 作者: RuiShu 项目源码 文件源码
def context_shift(x,
                  context,
                  shift=True,
                  scale=True,
                  scope=None,
                  reuse=None):

    B = context._shape_as_list()[-1]
    C = x._shape_as_list()[-1]
    ndim = len(x.shape)
    var_shape = [B] + [1] * (ndim - 2) + [C]

    with tf.variable_scope(scope, 'context_shift', reuse=reuse):
        output = x

        if scale:
            gamma = tf.get_variable('gamma', var_shape, initializer=tf.ones_initializer)
            output *= tf.tensordot(context, gamma, 1)

        if shift:
            beta = tf.get_variable('beta', var_shape, initializer=tf.zeros_initializer)
            output += tf.tensordot(context, beta, 1)

        output.set_shape(x.get_shape())

    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号