def forward(self, input_d, input_e, mask_d=None, mask_e=None):
    '''
    Args:
        input_d: Tensor
            the decoder input tensor with shape = [batch, length_decoder, input_size]
        input_e: Tensor
            the encoder (child) input tensor with shape = [batch, length_encoder, input_size]
        mask_d: Tensor or None
            the mask tensor for the decoder with shape = [batch, length_decoder]
            (accepted for interface compatibility but not applied in this snippet)
        mask_e: Tensor or None
            the mask tensor for the encoder with shape = [batch, length_encoder]
            (accepted for interface compatibility but not applied in this snippet)
    Returns: Tensor
        the energy tensor with shape = [batch, num_labels, length_decoder, length_encoder]
    '''
    assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are required to be equal.'
    batch, length_decoder, _ = input_d.size()
    _, length_encoder, _ = input_e.size()
    # compute the decoder part: [batch, length_decoder, input_size_decoder] * [input_size_decoder, hidden_size]
    # the output shape is [batch, length_decoder, hidden_size]
    # then --> [batch, 1, length_decoder, hidden_size]
    out_d = torch.matmul(input_d, self.W_d).unsqueeze(1)
    # compute the encoder part: [batch, length_encoder, input_size_encoder] * [input_size_encoder, hidden_size]
    # the output shape is [batch, length_encoder, hidden_size]
    # then --> [batch, length_encoder, 1, hidden_size]
    out_e = torch.matmul(input_e, self.W_e).unsqueeze(2)
    # broadcast-add the two parts: [batch, 1, length_decoder, h] + [batch, length_encoder, 1, h]
    # --> [batch, length_encoder, length_decoder, hidden_size]
    out = torch.tanh(out_d + out_e + self.b)
    # multiply by v:
    # [batch, length_encoder, length_decoder, hidden_size] * [hidden_size, num_labels]
    # --> [batch, length_encoder, length_decoder, num_labels]
    # then transpose --> [batch, num_labels, length_decoder, length_encoder]
    return torch.matmul(out, self.v).transpose(1, 3)
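
For reference, below is a minimal sketch of the enclosing module this forward method appears to belong to. The parameter shapes are inferred from the matmuls above; the class name BiAttention, the constructor arguments, and the Xavier initialization are illustrative assumptions, not taken from the original code.

import torch
import torch.nn as nn

class BiAttention(nn.Module):
    '''Minimal sketch of the enclosing module (names and initialization are assumptions).'''

    def __init__(self, input_size_decoder, input_size_encoder, hidden_size, num_labels):
        super(BiAttention, self).__init__()
        # projects decoder states: [input_size_decoder, hidden_size]
        self.W_d = nn.Parameter(torch.empty(input_size_decoder, hidden_size))
        # projects encoder states: [input_size_encoder, hidden_size]
        self.W_e = nn.Parameter(torch.empty(input_size_encoder, hidden_size))
        # bias broadcast over [batch, length_encoder, length_decoder, hidden_size]
        self.b = nn.Parameter(torch.zeros(hidden_size))
        # maps hidden features to per-label scores: [hidden_size, num_labels]
        self.v = nn.Parameter(torch.empty(hidden_size, num_labels))
        nn.init.xavier_uniform_(self.W_d)
        nn.init.xavier_uniform_(self.W_e)
        nn.init.xavier_uniform_(self.v)

    # the forward method shown above would be defined here

A quick shape check with hypothetical sizes confirms the documented output shape:

attention = BiAttention(input_size_decoder=512, input_size_encoder=512,
                        hidden_size=256, num_labels=40)
input_d = torch.randn(2, 10, 512)  # [batch=2, length_decoder=10, input_size=512]
input_e = torch.randn(2, 12, 512)  # [batch=2, length_encoder=12, input_size=512]
energy = attention(input_d, input_e)
print(energy.size())               # torch.Size([2, 40, 10, 12])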