def forward(self, query, ref):
    """Score every encoder position against the current decoder state.

    Computes additive attention scores
    ``u[b, l] = v . tanh(project_query(query)[b] + project_ref(ref)[b, :, l])``
    and, when ``self.use_tanh`` is set, bounds them as ``C * tanh(u)``.

    Args:
        query: hidden state of the decoder at the current time step,
            shape ``(batch, dim)``.
        ref: the set of hidden states from the encoder, shape
            ``(sourceL, batch, hidden_dim)``.

    Returns:
        Tuple ``(e, logits)``: ``e`` is the projected reference of shape
        ``(batch, hidden_dim, sourceL)`` and ``logits`` has shape
        ``(batch, sourceL)``.
    """
    # ref is now [batch x hidden_dim x sourceL]; presumably the layout
    # project_ref expects (TODO confirm against its definition).
    ref = ref.permute(1, 2, 0)
    q = self.project_query(query).unsqueeze(2)  # batch x dim x 1
    e = self.project_ref(ref)  # batch x hidden_dim x sourceL
    # q's trailing singleton dim broadcasts over sourceL, so no expanded
    # batch x dim x sourceL copy of the query has to be materialized.
    hidden = self.tanh(q + e)  # batch x dim x sourceL
    # v as batch x 1 x dim; expand() returns a view, not a copy.
    v_view = self.v.unsqueeze(0).unsqueeze(0).expand(hidden.size(0), -1, -1)
    # [batch x 1 x dim] @ [batch x dim x sourceL] -> batch x sourceL
    u = torch.bmm(v_view, hidden).squeeze(1)
    if self.use_tanh:
        # tanh bounds the logits to the open interval (-C, C).
        return e, self.C * self.tanh(u)
    return e, u
neural_combinatorial_rl.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录