nn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:dqa-net 作者: allenai 项目源码 文件源码
def softmax_with_base(shape, base_untiled, x, mask=None, name='sig'):
    if mask is not None:
        x += VERY_SMALL_NUMBER * (1.0 - mask)
    base_shape = shape[:-1] + [1]
    for _ in shape:
        base_untiled = tf.expand_dims(base_untiled, -1)
    base = tf.tile(base_untiled, base_shape)

    c_shape = shape[:-1] + [shape[-1] + 1]
    c = tf.concat(len(shape)-1, [base, x])
    c_flat = tf.reshape(c, [reduce(mul, shape[:-1], 1), c_shape[-1]])
    p_flat = tf.nn.softmax(c_flat)
    p_cat = tf.reshape(p_flat, c_shape)
    s_aug = tf.slice(p_cat, [0 for _ in shape], [i for i in shape[:-1]] + [1])
    s = tf.squeeze(s_aug, [len(shape)-1])
    sig = tf.sub(1.0, s, name="sig")
    p = tf.slice(p_cat, [0 for _ in shape[:-1]] + [1], shape)
    return sig, p
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号