ntm.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:ntm_keras 作者: flomlo 项目源码 文件源码
def _get_weight_vector(self, M, w_tm1, k, beta, g, s, gamma):
#        M = tf.Print(M, [M, w_tm1, k], message='get weights beg1: ')
#        M = tf.Print(M, [beta, g, s, gamma], message='get weights beg2: ')
        # Content adressing, see Chapter 3.3.1:
        num = beta * _cosine_distance(M, k)
        w_c  = K.softmax(num) # It turns out that equation (5) is just softmax.
        # Location adressing, see Chapter 3.3.2:
        # Equation 7:
        w_g = (g * w_c) + (1-g)*w_tm1
        # C_s is the circular convolution
        #C_w = K.sum((self.C[None, :, :, :] * w_g[:, None, None, :]),axis=3)
        # Equation 8:
        # TODO: Explain
        C_s = K.sum(K.repeat_elements(self.C[None, :, :, :], self.batch_size, axis=0) * s[:,:,None,None], axis=1)
        w_tilda = K.batch_dot(C_s, w_g)
        # Equation 9:
        w_out = _renorm(w_tilda ** gamma)

        return w_out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号