model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def forward(self, x):
        n_idx = 0
        c_idx = 1
        h_idx = 2
        w_idx = 3

        x = self.lookup_table(x)
        x = x.unsqueeze(c_idx)

        enc_outs = []
        for encoder in self.encoders:
            enc_ = F.relu(encoder(x))
            k_h = enc_.size()[h_idx]
            enc_ = F.max_pool2d(enc_, kernel_size=(k_h, 1))
            enc_ = enc_.squeeze(w_idx)
            enc_ = enc_.squeeze(h_idx)
            enc_outs.append(enc_)

        encoding = self.dropout(torch.cat(enc_outs, 1))
        return F.log_softmax(self.logistic(encoding))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号