rte_model.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Recognizing-Textual-Entailment 作者: codedecde 项目源码 文件源码
def forward(self, premise, hypothesis, training=False):
    '''Score the entailment classes for a batch of premise/hypothesis pairs.

    Token id 0 is treated as padding: it is zeroed out in the float masks
    that are handed to the GRU and attention helpers.

    Side effect: switches the module between train and eval mode via
    ``self.train(training)`` on every call.

    inputs:
        premise : batch x T token ids
        hypothesis : batch x T token ids
        training : when True, applies embedding dropout and puts the module
            in training mode; otherwise eval mode and no dropout.
    outputs :
        pred : batch x num_classes log-probabilities (log-softmax over the
            class dimension)
    '''
    self.train(training)
    batch_size = premise.size(0)
    dropout_p = self.options['DROPOUT']

    # Float masks: 1 for real tokens, 0 for padding (id 0). `dtype` is a
    # module-level tensor type defined elsewhere in this file.
    mask_p = torch.ne(premise, 0).type(dtype)
    mask_h = torch.ne(hypothesis, 0).type(dtype)

    # Embed + dropout, premise first (keeps the original RNG call order).
    encoded_p = F.dropout(self.embedding(premise), p=dropout_p, training=training)  # batch x T x n_embed
    encoded_h = F.dropout(self.embedding(hypothesis), p=dropout_p, training=training)  # batch x T x n_embed

    # The recurrent helpers consume sequence-first tensors.
    encoded_p = encoded_p.transpose(1, 0)  # T x batch x n_embed
    encoded_h = encoded_h.transpose(1, 0)  # T x batch x n_embed
    mask_p = mask_p.transpose(1, 0)  # T x batch
    mask_h = mask_h.transpose(1, 0)  # T x batch

    h_0 = self.init_hidden(batch_size)  # 1 x batch x n_dim
    o_p, h_p = self._gru_forward(self.p_gru, encoded_p, mask_p, h_0)  # o_p : T x batch x n_dim
    # Conditional encoding: the hypothesis GRU starts from the premise GRU's
    # final hidden state rather than from a fresh initial state.
    o_h, _ = self._gru_forward(self.h_gru, encoded_h, mask_h, h_p)  # o_h : T x batch x n_dim
    h_last = o_h[-1]  # batch x n_dim : hypothesis output at the last time step

    if self.options['WBW_ATTN']:
        # Attention RNN over the hypothesis outputs, attending to the premise
        # (presumably word-by-word attention, per the WBW_ATTN option name).
        r_0 = self.attn_rnn_init_hidden(batch_size)  # batch x n_dim
        r, _ = self._attn_rnn_forward(o_h, mask_h, r_0, o_p, mask_p)  # r : batch x n_dim
    else:
        # Single attention pass over the premise, queried by the last hypothesis output.
        r, _ = self._attention_forward(o_p, mask_p, h_last)  # r : batch x n_dim

    h_star = self.out(self._combine_last(r, h_last))  # batch x num_classes
    if self.options['LAST_NON_LINEAR']:
        h_star = F.leaky_relu(h_star)  # Non linear projection
    # Explicit dim: implicit-dim log_softmax is deprecated and warns on each
    # call; for batch x num_classes input the implicit choice was dim=1 too.
    return F.log_softmax(h_star, dim=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号