def forward(self, input):
    """
    input should be a [seq_len][batch_size] LongTensor of token ids.
    Assumes torch, torch.nn.functional as F, and autograd (torch.autograd)
    are imported at module level (pre-0.4 PyTorch Variable API).
    """
    seq_len = input.size()[0]
    batch_size = input.size()[1]
    # Reuse initial_state and initial_cell across calls; rebuild them only
    # when the batch size has changed since last time.
    if self.initial_state is None or self.initial_state.size()[1] != batch_size:
        # num_layers * 2 because the LSTM is bidirectional.
        self.initial_state = autograd.Variable(torch.zeros(
            self.num_layers * 2,
            batch_size,
            self.num_hidden
        ))
        self.initial_cell = autograd.Variable(torch.zeros(
            self.num_layers * 2,
            batch_size,
            self.num_hidden
        ))
        if input.is_cuda:
            self.initial_state = self.initial_state.cuda()
            self.initial_cell = self.initial_cell.cuda()
    x = self.embedding(input)
    x, _ = self.lstm(x, (self.initial_state, self.initial_cell))
    x = self.linear(x)
    x = F.sigmoid(x)
    # Sample a hard 0/1 "keep this token" decision per position from the
    # per-token selection probabilities.
    rationale_selected_node = torch.bernoulli(x)
    rationale_selected = rationale_selected_node.view(seq_len, batch_size)
    rationale_lengths = rationale_selected.sum(dim=0).int()
    max_rationale_length = rationale_lengths.max()
    # Pack the selected token ids into a [max_rationale_length][batch_size]
    # tensor, padded with pad_id.
    rationales = torch.LongTensor(max_rationale_length.data[0], batch_size)
    if input.is_cuda:
        rationales = rationales.cuda()
    rationales.fill_(self.pad_id)
    for n in range(batch_size):
        this_len = rationale_lengths[n].data[0]
        rationales[:this_len, n] = torch.masked_select(
            input[:, n].data, rationale_selected[:, n].data.byte()
        )
    return rationale_selected_node, rationale_selected, rationales, rationale_lengths
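
# --- Usage sketch (not from the original file; names below are assumptions) ---
# The method above belongs to a rationale-generator module whose constructor is
# not shown in this snippet. Assuming an instance `generator` built with an
# embedding layer, a bidirectional LSTM (hence num_layers * 2 initial states),
# a linear layer projecting each timestep to a single selection probability,
# and a pad_id attribute, a call could look like this (seq_len, batch_size and
# vocab_size are placeholder values):
#
#     input = autograd.Variable(
#         torch.LongTensor(seq_len, batch_size).random_(0, vocab_size))
#     selected_node, selected, rationales, lengths = generator(input)
#
# `selected` is a 0/1 mask of shape [seq_len, batch_size] sampled with
# torch.bernoulli; `rationales` holds the selected token ids per example,
# padded with pad_id up to the longest rationale in the batch.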