decoder.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ladder 作者: abhiskk 项目源码 文件源码
def __init__(self, d_in, d_out, use_cuda):
        """Create the decoder's learnable parameters and optional projection.

        Args:
            d_in: Width of the per-unit parameters ``a1``..``a10``; each one
                is a ``(1, d_in)`` tensor. Also the input size of ``V``.
            d_out: If not ``None``, a bias-free ``Linear(d_in, d_out)``
                (``self.V``) and a non-affine ``BatchNorm1d(d_out)``
                (``self.bn_normalize``) are created. If ``None``, neither
                attribute is set.
            use_cuda: If true, ``a1``..``a10`` are allocated on the GPU.
        """
        super(Decoder, self).__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.use_cuda = use_cuda

        def _param(fill):
            # Build a (1, d_in) Parameter filled with `fill`, placed on the GPU
            # when use_cuda is set. A single code path for both devices keeps
            # the CPU and CUDA initial values from drifting apart.
            ones = torch.ones(1, d_in)
            if use_cuda:
                ones = ones.cuda()
            return Parameter(fill * ones)

        # NOTE(review): these look like the ten weights of the ladder-network
        # combinator g(z~, u) (Rasmus et al., 2015); a2 and a7 start at 1 and
        # the rest at 0 -- confirm against forward().
        # Assigned one by one so the registration order a1..a10 is unchanged.
        self.a1 = _param(0.)
        self.a2 = _param(1.)
        self.a3 = _param(0.)
        self.a4 = _param(0.)
        self.a5 = _param(0.)

        self.a6 = _param(0.)
        self.a7 = _param(1.)
        self.a8 = _param(0.)
        self.a9 = _param(0.)
        self.a10 = _param(0.)

        if self.d_out is not None:
            self.V = torch.nn.Linear(d_in, d_out, bias=False)
            # Replace nn.Linear's default init with N(0, 1) / sqrt(d_in).
            self.V.weight.data = torch.randn(self.V.weight.data.size()) / np.sqrt(d_in)
            # batch-normalization for u (no learnable scale/shift: affine=False)
            self.bn_normalize = torch.nn.BatchNorm1d(d_out, affine=False)
            # NOTE(review): V and bn_normalize are created on the CPU even when
            # use_cuda is true; presumably the caller moves the module with
            # .cuda() -- confirm.

        # buffer for hat_z_l to be used for cost calculation
        self.buffer_hat_z_l = None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号