def forward(self, *hidden_states):
    # Expects inputs of shape (batch, hidden_size); softmax over dim=1 gates features.
    if len(hidden_states) == 1:
        # Single input: gate the hidden state by its own attention weights.
        hidden_state = hidden_states[0]
        attention_weights = F.softmax(torch.tanh(self.projection(hidden_state)), dim=1)
        return attention_weights * hidden_state
    elif len(hidden_states) == 2:
        left_hidden_state, right_hidden_state = hidden_states
        if self.mode == 0 or self.mode == 1:
            if self.mode == 0:
                # Mode 0: both sides share a single projection.
                left_attention_weights = F.softmax(torch.tanh(self.projection(left_hidden_state)), dim=1)
                right_attention_weights = F.softmax(torch.tanh(self.projection(right_hidden_state)), dim=1)
            elif self.mode == 1:
                # Mode 1: each side gets its own projection.
                left_attention_weights = F.softmax(torch.tanh(self.left_projection(left_hidden_state)), dim=1)
                right_attention_weights = F.softmax(torch.tanh(self.right_projection(right_hidden_state)), dim=1)
            return left_attention_weights * left_hidden_state, right_attention_weights * right_hidden_state
        elif self.mode == 2:
            # Mode 2: a single set of weights computed from the concatenation of both sides.
            hidden_state = torch.cat([left_hidden_state, right_hidden_state], dim=1)
            attention_weights = F.softmax(torch.tanh(self.projection(hidden_state)), dim=1)
            return attention_weights * left_hidden_state, attention_weights * right_hidden_state
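
For context, here is a minimal sketch of how the enclosing module and a call site might look. The class name GatedAttention, the constructor signature, and the Linear layer sizes are assumptions inferred from the attributes forward uses; they are not the author's original code.

import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the forward method above

class GatedAttention(nn.Module):  # hypothetical name; the original class is not shown
    def __init__(self, hidden_size, mode=0):
        super().__init__()
        self.mode = mode
        if mode == 1:
            # Mode 1: separate projections for the left and right inputs.
            self.left_projection = nn.Linear(hidden_size, hidden_size)
            self.right_projection = nn.Linear(hidden_size, hidden_size)
        else:
            # Mode 2 consumes the concatenation of both inputs, so its
            # projection maps 2 * hidden_size back down to hidden_size;
            # mode 0 and the single-input case map hidden_size to itself.
            in_size = 2 * hidden_size if mode == 2 else hidden_size
            self.projection = nn.Linear(in_size, hidden_size)

    # forward(self, *hidden_states) as defined above

# Usage: gate two 32-dimensional hidden states with independent projections (mode 1).
attn = GatedAttention(hidden_size=32, mode=1)
left, right = torch.randn(4, 32), torch.randn(4, 32)
gated_left, gated_right = attn(left, right)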