def forward(self, input_h, input_c, mask=None):
    '''
    Compute pairwise head->child attention energies and forbid self-attachment
    by setting every diagonal entry to -inf.

    Args:
        input_h: Tensor
            the head input tensor with shape = [batch, length, input_size]
        input_c: Tensor
            the child input tensor with shape = [batch, length, input_size]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]; passed to the
            attention module as both the decoder-side and encoder-side mask

    Returns: Tensor
        the energy tensor with shape = [batch, num_label, length, length]
        (per the attention module's contract — TODO confirm against
        self.attention's implementation)
    '''
    # only the sequence length is needed; batch size is implicit in `output`
    _, length, _ = input_h.size()
    # [batch, num_labels, length, length]
    output = self.attention(input_h, input_c, mask_d=mask, mask_e=mask)
    # Mask the diagonal with -inf so a token can never be its own head.
    # new_full keeps device/dtype consistent with `output` and replaces the
    # deprecated Variable(...)/.data.new(...).fill_(...) idiom.
    diag_mask = torch.diag(output.new_full((length,), float('-inf')))
    return output + diag_mask