ntm.py source file

python
Project: Neural-Turing-Machine  Author: yeoedward
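import tensorflow as tf

# NTMCell and rotate come from elsewhere in the project (not shown here).
def address(M0, w0, head):
    # Content focusing
    # Compute cosine similarity between the key and each memory row
    key = tf.expand_dims(head["key"], 1)
    # tf.matmul handles batched matrices (tf.batch_matmul was removed in TensorFlow 1.0)
    key_matches = tf.matmul(key, tf.transpose(M0, [0, 2, 1]))
    # Squeeze only axis 1 so that a batch of size 1 keeps its batch dimension
    key_matches = tf.squeeze(key_matches, axis=[1])
    key_mag = tf.expand_dims(NTMCell.magnitude(head["key"], 1), 1)
    M_col_mag = NTMCell.magnitude(M0, 2)
    cosine_sim = key_matches / (key_mag * M_col_mag)
    # Compute content weights
    wc = tf.nn.softmax(head["key_str"] * cosine_sim)

    # Location focusing
    # Interpolate between the content weights and the previous weights w0
    wg = head["interp"] * wc + (1 - head["interp"]) * w0
    # Shift the weights (rotate.ntm_rotate is a project-specific op)
    ws = rotate.ntm_rotate(wg, head["shift"])
    # Sharpen, then renormalize so that each row sums to 1
    ws_pow = tf.pow(ws, head["sharp"])
    w1 = ws_pow / tf.reduce_sum(ws_pow, 1, keepdims=True)

    return w1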
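
The function above relies on two helpers that are not shown on this page, NTMCell.magnitude and rotate.ntm_rotate. The following is a minimal NumPy sketch of the same addressing pipeline, written only to illustrate the steps. It assumes that magnitude is an L2 norm along the given axis, and that the shift step is the circular convolution from the Neural Turing Machine paper, w~(i) = sum_j wg(j) * s(i - j). The names magnitude, softmax, circular_shift and address_np are hypothetical, and the shapes in the comments are assumptions; the project's actual helpers may use different conventions.

```python
import numpy as np


def magnitude(x, axis):
    # L2 norm along `axis` (assumed behaviour of NTMCell.magnitude).
    return np.sqrt(np.sum(np.square(x), axis=axis))


def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def circular_shift(w, shift):
    # w: (batch, N) weights; shift: (batch, N) distribution over offsets 0..N-1.
    # out[b, i] = sum_j w[b, j] * shift[b, (i - j) mod N]
    n = w.shape[1]
    idx = (np.arange(n)[:, None] - np.arange(n)[None, :]) % n  # idx[i, j] = (i - j) mod n
    return np.einsum("bj,bij->bi", w, shift[:, idx])


def address_np(M0, w0, key, key_str, interp, shift, sharp):
    # Assumed shapes: M0 (batch, N, width), w0 (batch, N), key (batch, width),
    # key_str / interp / sharp (batch, 1), shift (batch, N).

    # Content focusing: cosine similarity between the key and each memory row.
    key_matches = np.einsum("bw,bnw->bn", key, M0)
    cosine_sim = key_matches / (magnitude(key, 1)[:, None] * magnitude(M0, 2))
    wc = softmax(key_str * cosine_sim, axis=1)

    # Location focusing: interpolate, shift, sharpen, renormalize.
    wg = interp * wc + (1 - interp) * w0
    ws = circular_shift(wg, shift)
    ws_pow = ws ** sharp
    return ws_pow / np.sum(ws_pow, axis=1, keepdims=True)
```

As a quick check of the sketch, take M0 = np.eye(4)[None], a key equal to the third memory row, a large key_str, interp = 1 and sharp = 1. A shift distribution concentrated on offset 0 then leaves the peak on the third slot, and one concentrated on offset 1 moves it to the fourth.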