layers.py source code


Project: repeval_rivercorners  Author: jabalazs
```python
import torch


# forward() method of a layer class in layers.py; the enclosing class is not shown.
def forward(self, lstm_out, lengths):
    """
    Args:
        lstm_out: A tensor of dimension (seq_len, batch_size, hidden_x_dirs)
            holding the outputs of a bidirectional LSTM.
        lengths: A 1D LongTensor of dimension (batch_size) holding the
            length of each sequence in the batch.

    Returns:
        A 2D tensor of the same type as lstm_out, of dimension
        (batch_size, hidden_x_dirs), containing the concatenated last
        hidden states of the forward and backward parts of the input.
    """
    seq_len = lstm_out.size(0)
    batch_size = lstm_out.size(1)
    hidden_x_dirs = lstm_out.size(2)
    # Integer division: repeat() needs an int, and hidden_x_dirs / 2 is a
    # float in Python 3.
    single_dir_hidden = hidden_x_dirs // 2

    lengths_fw = lengths
    lengths_bw = seq_len - lengths_fw

    # Build (1, batch_size, single_dir_hidden) index tensors for gather().
    rep_lengths_fw = lengths_fw.view(1, batch_size, 1)
    rep_lengths_fw = rep_lengths_fw.repeat(1, 1, single_dir_hidden)

    rep_lengths_bw = lengths_bw.view(1, batch_size, 1)
    rep_lengths_bw = rep_lengths_bw.repeat(1, 1, single_dir_hidden)

    # Split the last dimension into 2 chunks: forward and backward outputs.
    out_fw, out_bw = torch.chunk(lstm_out, 2, 2)

    # Forward direction: select time step (length - 1) of each sequence.
    h_t_fw = torch.gather(out_fw, 0, rep_lengths_fw - 1)
    # Backward direction: select time step (seq_len - length).
    h_t_bw = torch.gather(out_bw, 0, rep_lengths_bw)

    # (1, batch_size, hidden_x_dirs) -> (batch_size, hidden_x_dirs).
    # squeeze(0) keeps the batch dimension even when batch_size == 1.
    last_hidden_out = torch.cat([h_t_fw, h_t_bw], 2).squeeze(0)
    return last_hidden_out
```
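To see which time steps the two `torch.gather` calls pick, here is a dependency-free sketch in plain Python. It is an illustration, not the original code: `gather_dim0` and `last_hidden_states` are hypothetical names, nested lists stand in for tensors, and `gather_dim0` mimics `torch.gather(input, 0, index)`, which for a 3D input sets `out[i][j][k] = input[index[i][j][k]][j][k]`.

```python
from typing import List

# Nested lists standing in for tensors (illustration only).
Tensor3D = List[List[List[float]]]
Index3D = List[List[List[int]]]


def gather_dim0(values: Tensor3D, index: Index3D) -> Tensor3D:
    """Mimic torch.gather(values, 0, index) on 3D nested lists:
    out[i][j][k] = values[index[i][j][k]][j][k]."""
    return [
        [
            [values[index[i][j][k]][j][k] for k in range(len(index[i][j]))]
            for j in range(len(index[i]))
        ]
        for i in range(len(index))
    ]


def last_hidden_states(lstm_out: Tensor3D, lengths: List[int]) -> List[List[float]]:
    """Hypothetical restatement of forward() (not in the original layers.py):
    (seq_len, batch_size, hidden_x_dirs) -> (batch_size, hidden_x_dirs)."""
    seq_len = len(lstm_out)
    batch_size = len(lstm_out[0])
    single_dir_hidden = len(lstm_out[0][0]) // 2

    # Like torch.chunk(lstm_out, 2, 2): split the last dimension in half.
    out_fw = [[row[:single_dir_hidden] for row in step] for step in lstm_out]
    out_bw = [[row[single_dir_hidden:] for row in step] for step in lstm_out]

    # Index "tensors" of shape (1, batch_size, single_dir_hidden),
    # like the view(...).repeat(...) calls.
    idx_fw = [[[lengths[b] - 1] * single_dir_hidden for b in range(batch_size)]]
    idx_bw = [[[seq_len - lengths[b]] * single_dir_hidden for b in range(batch_size)]]

    # Keep the single gathered time step (the size-1 first dimension).
    h_fw = gather_dim0(out_fw, idx_fw)[0]
    h_bw = gather_dim0(out_bw, idx_bw)[0]

    # Like torch.cat([...], 2) followed by dropping the leading size-1 dimension.
    return [h_fw[b] + h_bw[b] for b in range(batch_size)]


if __name__ == "__main__":
    # Entry (t, b, k) = 100*t + 10*b + k, so the chosen time step t is visible.
    lstm_out = [
        [[float(100 * t + 10 * b + k) for k in range(4)] for b in range(2)]
        for t in range(3)
    ]
    print(last_hidden_states(lstm_out, [3, 1]))
    # [[200.0, 201.0, 2.0, 3.0], [10.0, 11.0, 212.0, 213.0]]
```

With `seq_len = 3`, the length-3 sequence takes its forward state from t = 2 and its backward state from t = 0, while the length-1 sequence takes t = 0 and t = 2 respectively. The backward index `seq_len - length` is carried over unchanged from the original; whether that position holds the backward direction's final state depends on how the inputs were padded, which the snippet does not show.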