dynamic2.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pytorch-dynamic-batching 作者: mrdrozdov-github 项目源码 文件源码
def forward(self, x, lengths):
        """Run ``self.rnn`` over a variable-length batch and classify the final state.

        Dynamic batching: at step ``t`` only the rows that are "active" are fed
        through ``self.rnn``. Inactive rows carry their previous hidden state
        forward unchanged. Row ``i`` is active for
        ``t >= max(lengths) - lengths[i]``, so it consumes only its LAST
        ``lengths[i]`` columns of ``x``. Sequences are therefore expected to be
        right-aligned (left-padded). NOTE(review): confirm against the data
        loader.

        Args:
            x: integer tensor used as row indices into
                ``self.initial_embeddings``; dim 0 is the batch and dim 1 the
                step (assumes ``initial_embeddings`` is 2-D -- TODO confirm).
                It is converted with ``x.numpy()``, so it must be a CPU tensor,
                and it needs at least ``max(lengths)`` columns.
            lengths: non-empty sequence of per-row lengths (``max`` of an empty
                sequence raises ``ValueError``).

        Returns:
            ``F.log_softmax`` over the output of ``self.l1``, one row per batch
            element.
        """
        batch_size = x.size(0)
        max_len = max(lengths)
        volatile = not self.training

        # Embedding lookup for every index in x: shape x.shape + trailing dims.
        emb = Variable(torch.from_numpy(
            self.initial_embeddings.take(x.numpy(), 0)),
            volatile=volatile)

        # Initial hidden state: zeros for every row.
        h = Variable(torch.zeros(batch_size, self.model_dim).float(), volatile=volatile)

        for t in range(max_len):
            # Rows whose right-aligned sequence has started by step t. Uses the
            # hoisted max_len instead of recomputing max(lengths) per row.
            indices = [i for i, l in enumerate(lengths) if l >= max_len - t]
            idx = Variable(torch.LongTensor(indices), volatile=volatile)

            # Step only the active rows.
            batch = torch.index_select(emb[:, t, :], 0, idx)
            h_prev = torch.index_select(h, 0, idx)
            h_next = self.rnn(batch, h_prev)

            # Out-of-place scatter (keeps autograd intact): active rows take the
            # new state, inactive rows keep their previous state.
            h = h.index_copy(0, idx, h_next)

        h = F.relu(self.l0(F.dropout(h, 0.5, self.training)))
        h = F.relu(self.l1(F.dropout(h, 0.5, self.training)))
        # Explicit dim=1: the same axis the deprecated implicit rule picks for
        # 2-D input, without the deprecation warning.
        y = F.log_softmax(h, dim=1)
        return y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号