def forward(self, ctx):
    """Encode an interleaved (count, value) context into a hidden state.

    Rows of ``ctx`` along dim 0 alternate: even rows (0, 2, 4, ...) are
    counts and odd rows (1, 3, 5, ...) are values. Only complete
    (count, value) pairs are used; if ``ctx.size(0)`` is odd, the trailing
    unpaired row is ignored.

    Args:
        ctx: tensor with at least two dims. Dim 0 holds the interleaved
            counts/values. Dim 1 becomes the leading dim of the flattened
            hidden state (presumably the batch dimension -- TODO confirm
            against callers).

    Returns:
        The output of ``self.encoder`` on the flattened hidden state, with
        a new leading dimension of size 1 added.
    """
    # Only whole (count, value) pairs are used; an odd trailing row is dropped.
    end = 2 * (ctx.size(0) // 2)
    # Strided slicing selects the same rows as an index_select over
    # 2*i / 2*i+1 indices. It stays on ctx's own device, so no index tensor
    # has to be built on the host and copied to the device on every call.
    cnt = ctx[0:end:2]
    val = ctx[1:end:2]
    # embed counts and values
    cnt_emb = self.cnt_enc(cnt)
    val_emb = self.val_enc(val)
    # element wise multiplication to get a hidden state
    h = torch.mul(cnt_emb, val_emb)
    # Bring ctx's dim 1 to the front and flatten the remaining dims of each
    # entry into one vector, then run it through the MLP encoder.
    h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
    ctx_h = self.encoder(h).unsqueeze(0)
    return ctx_h
评论列表
文章目录