train.py source code


Project: rationalizing-neural-predictions | Author: hughperkins
```python
import torch


def forward(self, input):
    """
    input should be a LongTensor of token ids, [seq_len][batch_size]
    """
    seq_len, batch_size = input.size()
    # Reuse initial_state and initial_cell if the batch size hasn't changed
    # since the last call. The first dimension is num_layers * num_directions,
    # so the factor of 2 implies a bidirectional LSTM.
    if self.initial_state is None or self.initial_state.size(1) != batch_size:
        # Allocated on the same device as the input (replaces the .cuda() calls).
        self.initial_state = torch.zeros(
            self.num_layers * 2, batch_size, self.num_hidden, device=input.device
        )
        self.initial_cell = torch.zeros(
            self.num_layers * 2, batch_size, self.num_hidden, device=input.device
        )
    x = self.embedding(input)
    x, _ = self.lstm(x, (self.initial_state, self.initial_cell))
    x = self.linear(x)    # [seq_len, batch_size, 1]
    x = torch.sigmoid(x)  # per-token selection probability
    # Sample a hard 0/1 selection for every token.
    rationale_selected_node = torch.bernoulli(x)
    rationale_selected = rationale_selected_node.view(seq_len, batch_size)
    rationale_lengths = rationale_selected.sum(dim=0).int()
    max_rationale_length = int(rationale_lengths.max().item())
    # Padded output: column n holds the selected tokens of example n, in order.
    rationales = torch.full(
        (max_rationale_length, batch_size), self.pad_id,
        dtype=torch.long, device=input.device,
    )
    for n in range(batch_size):
        this_len = int(rationale_lengths[n].item())
        rationales[:this_len, n] = torch.masked_select(
            input[:, n], rationale_selected[:, n].bool()
        )
    return rationale_selected_node, rationale_selected, rationales, rationale_lengths
```
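The per-example loop at the end of `forward` packs each column's selected tokens to the top of a `pad_id`-filled buffer. The sketch below reproduces that packing on plain Python lists so the result can be checked by hand. `pack_rationales` is a hypothetical helper for illustration, not part of the project, and the token ids are made up.

```python
from typing import List, Tuple


def pack_rationales(
    tokens: List[List[int]],  # [seq_len][batch_size] token ids
    mask: List[List[int]],    # [seq_len][batch_size] 0/1 selections
    pad_id: int,
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: pack each column's selected tokens into a
    padded [max_len][batch_size] grid, mirroring the masked_select loop."""
    seq_len = len(tokens)
    batch_size = len(tokens[0]) if seq_len else 0
    # Number of selected tokens per batch column (like rationale_lengths).
    lengths = [sum(mask[t][n] for t in range(seq_len)) for n in range(batch_size)]
    max_len = max(lengths, default=0)
    # Start from an all-padding grid, like rationales.fill_(pad_id).
    packed = [[pad_id] * batch_size for _ in range(max_len)]
    for n in range(batch_size):
        # Selected tokens keep their original order, as masked_select does.
        selected = [tokens[t][n] for t in range(seq_len) if mask[t][n]]
        for i, tok in enumerate(selected):
            packed[i][n] = tok
    return packed, lengths


tokens = [[11, 21], [12, 22], [13, 23]]
mask = [[1, 0], [0, 0], [1, 1]]
packed, lengths = pack_rationales(tokens, mask, pad_id=0)
print(packed)   # [[11, 23], [13, 0]]
print(lengths)  # [2, 1]
```

Example 0 keeps two tokens and example 1 keeps one, so the shorter column is padded with `pad_id`. Looping over the batch is simple but runs once per example; for large batches a vectorized gather would avoid the Python loop at some cost in readability.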
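`torch.bernoulli` in `forward` draws each token's selection independently. Writing $p_{t,n}$ for the sigmoid output at position $t$ of example $n$, $z_{t,n}$ for the sampled bit (`rationale_selected`), and $T$ for `seq_len`, the distribution being sampled, its log-probability, and the per-example length the code stores in `rationale_lengths` are:

```latex
z_{t,n} \sim \mathrm{Bernoulli}(p_{t,n}), \qquad
\log p(z_{\cdot,n} \mid x_{\cdot,n})
  = \sum_{t=1}^{T} \bigl[ z_{t,n} \log p_{t,n} + (1 - z_{t,n}) \log (1 - p_{t,n}) \bigr], \qquad
L_n = \sum_{t=1}^{T} z_{t,n}
```

The training code is not part of this snippet. Because the sampling step itself passes no useful gradient back to $p_{t,n}$, one common option (an assumption here, not confirmed by the source) is a score-function (REINFORCE-style) estimator, which differentiates this log-probability instead.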