net.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:MemoryNetwork 作者: aonotas 项目源码 文件源码
def encode(self, x_input, x_query, answer):
        """Return the softmax cross-entropy loss for an input/query/answer triple.

        The input is embedded twice (via ``self.encode_input`` and
        ``self.encode_output``) and the query once (via ``self.encode_query``).
        Softmax-normalised match scores between the input and query embeddings
        weight the output embeddings; that weighted result is added to the
        query embedding, passed through ``self.W`` and scored against
        ``answer``.

        Args:
            x_input: Forwarded to ``self.encode_input`` and
                ``self.encode_output``.
            x_query: Forwarded to ``self.encode_query``.
            answer: Target forwarded to ``functions.softmax_cross_entropy``.

        Returns:
            Whatever ``functions.softmax_cross_entropy`` returns for the
            prediction and ``answer``.

        NOTE(review): tensor shapes are not visible from this method; confirm
        that the default softmax axis is the intended (memory) axis.
        """
        memory = self.encode_input(x_input)
        query = self.encode_query(x_query)

        # Match scores: memory times the transposed query, then softmax.
        attention = functions.softmax(
            functions.matmul(memory, query, transb=True))

        output_memory = self.encode_output(x_input)

        # Weight the output embeddings by the attention. Axes 0 and 1 are
        # swapped before the product and swapped back afterwards so that the
        # result can be added to the query embedding.
        # Original author's shape notes: (2, 50, 1) after the matmul and
        # (2, 50) after the swap -- TODO confirm.
        weighted = functions.matmul(
            functions.swapaxes(output_memory, 1, 0), attention)
        response = functions.swapaxes(weighted, 1, 0)

        scores = self.W(query + response)
        return functions.softmax_cross_entropy(scores, answer)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号