recursive.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:dl4nlp_in_theano 作者: luyaojie 项目源码 文件源码
def forward(self, x, seq):
        """
        :param x:   (length, dim)
        :param seq: (length - 1, 3)
        :return:
        """
        # (length, dim) -> (2 * length - 1, dim)
        vector = T.concatenate([x, T.zeros_like(x)[:-1, :]], axis=0)
        # vector = theano.printing.Print()(vector)
        # scan length-1 times
        hs, _ = theano.scan(fn=self.encode,
                            sequences=seq,
                            outputs_info=[vector, shared_scalar(0)],
                            name="compose_phrase")
        comp_vec_init = hs[0][-1][-1]
        comp_rec_init = T.sum(hs[1])
        if self.normalize:
            hidden = x[0] / x[0].norm(2)
        else:
            hidden = x[0]
        comp_vec = ifelse(x.shape[0] > 1, comp_vec_init, hidden)
        comp_rec = ifelse(x.shape[0] > 1, comp_rec_init, shared_zero_scalar())
        return comp_vec, comp_rec
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号