polyphonic_data_loader.py source code

python

Project: pyro  Author: uber
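import numpy as np
import torch
from torch.autograd import Variable
from pyro.util import ng_zeros  # assumed import path for this legacy Pyro helper


# reverse each sequence in a padded mini-batch along the time axis;
# positions at or beyond seq_lengths[b] are left as zeros
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = ng_zeros(mini_batch.size(), type_as=mini_batch.data)
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        # indices T-1, ..., 0 pick the valid time steps in reverse order
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = Variable(torch.cuda.LongTensor(time_slice)) if 'cuda' in mini_batch.data.type() \
            else Variable(torch.LongTensor(time_slice))
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch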
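
The function above relies on `Variable` and `ng_zeros`, which belong to early PyTorch and Pyro releases. As a rough sketch only, and not the project's own code, the same reversal might be written for PyTorch 0.4 or later, where `Variable` has been merged into `Tensor`. The name `reverse_sequences` and the toy batch are hypothetical and used purely for illustration.

```python
import torch


def reverse_sequences(mini_batch, seq_lengths):
    """Reverse each padded sequence in time; trailing padding stays zero.

    mini_batch:  tensor of shape (batch, max_time, features)
    seq_lengths: valid length of each sequence (list or 1-D tensor of ints)
    """
    # zeros_like keeps the dtype and device of the input batch
    reversed_mini_batch = torch.zeros_like(mini_batch)
    for b in range(mini_batch.size(0)):
        T = int(seq_lengths[b])
        # flip only the first T (valid) time steps of sequence b
        reversed_mini_batch[b, :T] = torch.flip(mini_batch[b, :T], dims=[0])
    return reversed_mini_batch


# hypothetical toy batch: lengths 3 and 2, padded to 3 steps, one feature
batch = torch.tensor([[[1.0], [2.0], [3.0]],
                      [[4.0], [5.0], [0.0]]])
out = reverse_sequences(batch, [3, 2])
```

Because `torch.zeros_like` inherits the device of its input, this sketch needs no explicit CUDA branch. The loop over the batch is kept so that it mirrors the original function's structure.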


# this function takes the hidden state as output by the PyTorch rnn and
# unpacks it; it also reverses each sequence temporally