sequence2sequence.py source code

python

Project: lang-reps   Author: chaitanyamalaviya
# Method of the seq2seq model class in sequence2sequence.py; assumes
# `import dynet` at module level.
def get_batch_loss(self, input_batch, output_batch):

        dynet.renew_cg()  # start a fresh computation graph for this minibatch

        # Dimension: maxSentLength * minibatch_size
        wids = []
        wids_reversed = []

        # List of lists to store whether an input is
        # present(1)/absent(0) for an example at a time step
        # masks = [] # Dimension: maxSentLength * minibatch_size

        # tot_words = 0
        maxSentLength = max(len(sent) for sent in input_batch)
        for j in range(maxSentLength):
            # j-th word id of each sentence (and j-th from the end for the
            # reversed copy), padding short sentences with END_TOK
            wids.append([(self.src_vocab[sent[j]].i if len(sent) > j else self.src_vocab.END_TOK.i)
                         for sent in input_batch])
            wids_reversed.append([(self.src_vocab[sent[len(sent) - j - 1]].i if len(sent) > j else self.src_vocab.END_TOK.i)
                                  for sent in input_batch])
            # mask = [(1 if len(sent) > j else 0) for sent in input_batch]
            # masks.append(mask)
            # tot_words += sum(mask)
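
        # Illustration (hypothetical ids): for input_batch = [["a", "b", "c"],
        # ["d", "e"]], the loop above yields time-major, END-padded id lists:
        #   wids          = [[a, d], [b, e], [c, END]]
        #   wids_reversed = [[c, e], [b, d], [a, END]]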

        embedded_batch = self.embed_batch_seq(wids)
        embedded_batch_reverse = self.embed_batch_seq(wids_reversed)
        encoded_batch = self.encode_batch_seq(embedded_batch, embedded_batch_reverse)
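
        # embed_batch_seq, encode_batch_seq and decode_batch are defined
        # elsewhere in this class (not shown on this page); embedding a
        # time-major id matrix like wids is typically done one time step at a
        # time with dynet.lookup_batch.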

        # pass last hidden state of encoder to decoder
        return self.decode_batch(encoded_batch, output_batch)
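
For context, here is a minimal sketch of how a batch-loss method like this is
typically driven during training in DyNet. The trainer choice and the names
`model`, `model.pc`, and `training_pairs` are illustrative assumptions, not
taken from the lang-reps project:

import random
import dynet

# Assumed setup: `model` is an instance of the class that defines
# get_batch_loss, and `model.pc` is its dynet.ParameterCollection.
trainer = dynet.SimpleSGDTrainer(model.pc)
minibatch_size = 32
num_epochs = 10

for epoch in range(num_epochs):
    random.shuffle(training_pairs)  # list of (src_sentence, tgt_sentence) pairs
    for i in range(0, len(training_pairs), minibatch_size):
        batch = training_pairs[i:i + minibatch_size]
        src_batch = [src for src, tgt in batch]
        tgt_batch = [tgt for src, tgt in batch]
        loss = model.get_batch_loss(src_batch, tgt_batch)  # builds a fresh graph
        loss.value()      # run the forward pass
        loss.backward()   # backpropagate through the batched graph
        trainer.update()  # apply the gradient update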