seq2seq_model.py source code

python

Project: DSTC6-End-to-End-Conversation-Modeling Author: dialogtekgeek
import chainer
import chainer.functions as F


# Method of the seq2seq model class, shown here without its enclosing class.
def loss(self, es, x, y, t):
    """Forward propagation and loss calculation.

    Args:
        es (pair of ~chainer.Variable): encoder state
        x (list of ~chainer.Variable): list of input sequences
        y (list of ~chainer.Variable): list of output sequences
        t (list of ~chainer.Variable): list of target sequences;
            if t is None, only the states are returned
    Returns:
        es (pair of ~chainer.Variable(s)): encoder state
        ds (pair of ~chainer.Variable(s)): decoder state
        loss (~chainer.Variable): cross-entropy loss (only when t is given)
    """
    # Encode the inputs, then run the decoder from the resulting encoder state.
    es, ey = self.encoder(es, x)
    ds, dy = self.decoder(es, y)
    if t is not None:
        loss = F.softmax_cross_entropy(dy, t)
        # avoid NaN gradients (See: https://github.com/pfnet/chainer/issues/2505)
        # The zero-weighted sum keeps the encoder outputs ey in the graph
        # without changing the loss value.
        if chainer.config.train:
            loss += F.sum(F.concat(ey, axis=0)) * 0
        return es, ds, loss
    else:  # if target is None, it only returns states
        return es, ds
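
The loss term above is a softmax cross-entropy over the decoder outputs. As a rough illustration of what such a loss computes, here is a minimal pure-Python sketch. It is not Chainer's implementation: the function name, the list-based inputs, and the `ignore_label` handling are assumptions made for this example only.

```python
import math


def softmax_cross_entropy(logits, targets, ignore_label=-1):
    """Mean negative log-probability of each target class under a softmax.

    logits  : list of rows, each a list of float scores (one row per token)
    targets : list of int class indices, one per row
    Rows whose target equals ignore_label are skipped (assumed padding rule).
    """
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_label:
            continue
        # log-sum-exp with the row maximum subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    # Average over the rows that were actually counted
    return total / count if count else 0.0


# Hypothetical scores for two output tokens over a three-word vocabulary
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
loss = softmax_cross_entropy(logits, targets)
```

With uniform scores, as in the second row, the per-token loss equals the log of the vocabulary size, which makes a handy sanity check.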