def __init__(self, d_in, d_out, use_cuda):
    """Set up the decoder layer's learnable parameters.

    Creates ten per-unit weight vectors ``a1`` .. ``a10`` of shape
    ``(1, d_in)`` (presumably the combinator coefficients of a
    ladder-network decoder — TODO confirm against the forward pass),
    plus an optional linear projection ``V`` and batch-norm for the
    layer below.

    Args:
        d_in: width of this decoder layer's input.
        d_out: width of the projected output; ``None`` skips creating
            ``V`` and ``bn_normalize`` (e.g. for the last layer).
        use_cuda: when True, allocate the ``a*`` parameters on the GPU.
    """
    super(Decoder, self).__init__()
    self.d_in = d_in
    self.d_out = d_out
    self.use_cuda = use_cuda

    # a2 and a7 start at 1, the rest at 0.  Building them in one loop
    # removes the 20-line cuda/non-cuda copy-paste of the original;
    # setattr goes through nn.Module.__setattr__, so each Parameter is
    # registered exactly as the explicit assignments were.
    init_values = (0., 1., 0., 0., 0., 0., 1., 0., 0., 0.)
    for idx, val in enumerate(init_values, start=1):
        t = val * torch.ones(1, d_in)
        if self.use_cuda:
            t = t.cuda()
        setattr(self, 'a%d' % idx, Parameter(t))

    if self.d_out is not None:
        # Projection to the next decoder layer; weights re-drawn from a
        # Gaussian scaled by 1/sqrt(d_in) (fan-in scaling).
        self.V = torch.nn.Linear(d_in, d_out, bias=False)
        self.V.weight.data = torch.randn(self.V.weight.data.size()) / np.sqrt(d_in)
        # batch-normalization for u
        self.bn_normalize = torch.nn.BatchNorm1d(d_out, affine=False)
    # buffer for hat_z_l to be used for cost calculation
    self.buffer_hat_z_l = None
# (stray scraped-page text, kept as a comment so the file parses:
#  "评论列表" = comment list, "文章目录" = article table of contents)