normalization.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:tensorbayes 作者: RuiShu 项目源码 文件源码
def lookup_shift(x,
                 context,
                 shift=True,
                 scale=True,
                 scope=None,
                 reuse=None):

    B = context._shape_as_list()[-1]
    C = x._shape_as_list()[-1]
    ndim = len(x.shape)
    var_shape = [B] + [1] * (ndim - 2) + [C]

    with tf.variable_scope(scope, 'lookup_shift', reuse=reuse):
        output = x
        ids = tf.argmax(context, -1)

        if scale:
            gamma = tf.get_variable('gamma', var_shape, initializer=tf.ones_initializer)
            output *= tf.nn.embedding_lookup(gamma, ids)

        if shift:
            beta = tf.get_variable('beta', var_shape, initializer=tf.zeros_initializer)
            output += tf.nn.embedding_lookup(beta, ids)

    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号