def setUp(self):
    """Create the fixtures shared by the encoder-base tests.

    Sets up a 3-layer bidirectional LSTM, a stateful ``_EncoderBase``,
    a random 5 x 7 x 3 input whose trailing positions are zeroed on
    rows 1 and 3, and a 5 x 7 mask in which rows 2 and 4 are entirely
    zero.  The restoration and sorting indices returned by
    ``sort_batch_by_length`` for this batch are stored as well.
    """
    super(TestEncoderBase, self).setUp()

    # The LSTM is built before the random input is drawn; keep this
    # order so a seeded RNG yields the same values.
    self.lstm = LSTM(
        input_size=3,
        hidden_size=7,
        num_layers=3,
        bidirectional=True,
        batch_first=True,
    )
    self.encoder_base = _EncoderBase(stateful=True)

    # (row, first zeroed position) pairs for the input tensor.
    inputs = Variable(torch.rand([5, 7, 3]))
    for row, first_zeroed in ((1, 6), (3, 2)):
        inputs[row, first_zeroed:, :] = 0
    self.tensor = inputs

    # (row, first masked position) pairs; a start of 0 means the row
    # is completely masked (rows 2 and 4).
    padding_mask = Variable(torch.ones(5, 7))
    for row, first_masked in ((1, 6), (2, 0), (3, 2), (4, 0)):
        padding_mask[row, first_masked:] = 0
    self.mask = padding_mask

    self.batch_size = 5
    # Rows 0, 1 and 3 keep at least one unmasked position.
    self.num_valid = 3

    lengths = get_lengths_from_binary_sequence_mask(padding_mask)
    # Only the last two of the four returned values are kept.
    _sorted, _sorted_lengths, restore_order, sort_order = sort_batch_by_length(inputs, lengths)
    self.sorting_indices = sort_order
    self.restoration_indices = restore_order
评论列表
文章目录