modules.py 文件源码

python
阅读 54 收藏 0 点赞 0 评论 0

项目:end-to-end-negotiator 作者: facebookresearch 项目源码 文件源码
def forward(self, ctx):
        idx = np.arange(ctx.size(0) // 2)
        # extract counts and values
        cnt_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 0)))
        val_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 1)))

        cnt = ctx.index_select(0, cnt_idx)
        val = ctx.index_select(0, val_idx)

        # embed counts and values
        cnt_emb = self.cnt_enc(cnt)
        val_emb = self.val_enc(val)

        # element wise multiplication to get a hidden state
        h = torch.mul(cnt_emb, val_emb)
        # run the hidden state through the MLP
        h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
        ctx_h = self.encoder(h).unsqueeze(0)
        return ctx_h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号