ntm_minibatch.py source code

python

Project: CopyNet  Author: MultiPath
def __call__(self, X, w_temp, m_temp):
    # Input shapes
    # X:      (nb_samples, input_dim)
    # w_temp: (nb_samples, memory_dim)                previous addressing weights
    # m_temp: (nb_samples, memory_dim, memory_width)  memory tensor

    # Project the input into the addressing parameters.
    key = dot(X, self.W_key, self.b_key)  # (nb_samples, memory_width)
    shift = self.softmax(
        dot(X, self.W_shift, self.b_shift))  # (nb_samples, shift_width)

    # Key strength, beta >= 0.
    beta = self.softplus(dot(X, self.W_beta, self.b_beta))[:, None]  # (nb_samples, 1)
    # Sharpening exponent, gamma >= 1.
    gamma = self.softplus(dot(X, self.W_gama, self.b_gama)) + 1.  # (nb_samples,)
    gamma = gamma[:, None]  # (nb_samples, 1)
    # Interpolation gate, g in (0, 1).
    g = self.sigmoid(dot(X, self.W_g, self.b_g))[:, None]  # (nb_samples, 1)

    # Collected but not used further in this method.
    signal = [key, shift, beta, gamma, g]

    # Every weighting below has shape (nb_samples, memory_dim).
    # 1. Content-based addressing
    w_c = self.softmax(beta * cosine_sim2d(key, m_temp))
    # 2. Interpolation with the previous weights
    w_g = g * w_c + (1 - g) * w_temp
    # 3. Convolutional shift
    w_s = shift_convolve2d(w_g, shift, self.shift_conv)
    # 4. Sharpening, then renormalization so each row sums to 1
    w_p = w_s ** gamma
    w_t = w_p / T.sum(w_p, axis=1)[:, None]
    return w_t
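
The four addressing steps above (content addressing, interpolation, shift, sharpening) can be tried out without Theano. The following is a minimal NumPy sketch, not the project's implementation: `softmax`, `cosine_sim2d`, `circular_shift` and `address` are hypothetical stand-ins written for illustration, and the shift convention (offsets centred on zero, applied with a circular roll) is an assumption, since `shift_convolve2d` and `self.shift_conv` are not shown in the listing.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cosine_sim2d(key, memory, eps=1e-8):
    # Hypothetical stand-in for the project's helper of the same name.
    # key: (n, width), memory: (n, slots, width) -> (n, slots)
    num = np.einsum("nw,nsw->ns", key, memory)
    den = np.linalg.norm(key, axis=1)[:, None] * np.linalg.norm(memory, axis=2)
    return num / (den + eps)


def circular_shift(w, shift):
    # Assumed convention: column j of `shift` is the probability of moving
    # the weighting by offset (j - shift_width // 2), wrapping around.
    width = shift.shape[1]
    offsets = np.arange(width) - width // 2
    out = np.zeros_like(w)
    for j, off in enumerate(offsets):
        out += shift[:, j:j + 1] * np.roll(w, off, axis=1)
    return out


def address(key, shift, beta, gamma, g, w_prev, memory):
    # beta, gamma, g have shape (n, 1); w_prev has shape (n, slots).
    w_c = softmax(beta * cosine_sim2d(key, memory), axis=1)  # content addressing
    w_g = g * w_c + (1 - g) * w_prev                         # interpolation
    w_s = circular_shift(w_g, shift)                         # shift
    w_p = w_s ** gamma                                       # sharpening
    return w_p / w_p.sum(axis=1, keepdims=True)              # renormalize
```

With `g = 0`, `gamma = 1` and a shift distribution concentrated on offset 0, `address` returns `w_prev` unchanged, which is a quick sanity check that the interpolation and shift steps are wired up as intended.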