def address(M0, w0, head, eps=1e-8):
    """Compute a new head weighting from memory, the previous weighting and head parameters.

    Two stages, following the comments in the original code:

    * Content focusing: cosine similarity between ``head["key"]`` and each
      memory row, scaled by ``head["key_str"]`` and passed through softmax.
    * Location focusing: interpolate with the previous weighting ``w0``
      using ``head["interp"]``, shift with ``rotate.ntm_rotate`` using
      ``head["shift"]``, raise to the power ``head["sharp"]``, and
      renormalise along axis 1.

    Args:
        M0: Memory tensor. Must be rank 3 (see the ``[0, 2, 1]`` transpose);
            presumably (batch, N, W) -- TODO confirm against caller.
        w0: Previous weighting, presumably (batch, N) -- TODO confirm.
        head: Dict of head parameters. Reads keys ``"key"``, ``"key_str"``,
            ``"interp"``, ``"shift"`` and ``"sharp"``.
        eps: Small constant added to the cosine-similarity denominator.
            Without it, a zero-norm key or memory row gives 0/0 = NaN.

    Returns:
        The new weighting, normalised so it sums to 1 along axis 1.
    """
    key = head["key"]

    # --- Content focusing: cosine similarity --------------------------------
    # Dot product of the key against every memory row (batched matmul).
    key_matches = tf.batch_matmul(tf.expand_dims(key, 1),
                                  tf.transpose(M0, [0, 2, 1]))
    # Remove only the singleton axis added by expand_dims. An unqualified
    # squeeze would also drop the batch axis when the batch size is 1.
    key_matches = tf.squeeze(key_matches, [1])
    key_mag = tf.expand_dims(NTMCell.magnitude(key, 1), 1)
    M_col_mag = NTMCell.magnitude(M0, 2)
    # eps keeps a zero-norm key or memory row from producing NaN.
    cosine_sim = key_matches / (key_mag * M_col_mag + eps)

    # Content weights: key strength scales the similarities before softmax.
    wc = tf.nn.softmax(head["key_str"] * cosine_sim)

    # --- Location focusing ---------------------------------------------------
    # Blend content weights with the previous weighting.
    wg = head["interp"] * wc + (1 - head["interp"]) * w0
    ws = rotate.ntm_rotate(wg, head["shift"])
    # Sharpen, then renormalise so the weighting sums to 1 along axis 1.
    ws_pow = tf.pow(ws, head["sharp"])
    w1 = ws_pow / tf.reduce_sum(ws_pow, 1, keep_dims=True)
    return w1
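# 评论列表 ("Comment list") -- navigation text left over from the web page this code was copied from.
# 文章目录 ("Table of contents") -- navigation text left over from the web page this code was copied from.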