def __init__(self, z_dim, transition_dim):
super(GatedTransition, self).__init__()
# initialize the six linear transformations used in the neural network
self.lin_gate_z_to_hidden = nn.Linear(z_dim, transition_dim)
self.lin_gate_hidden_to_z = nn.Linear(transition_dim, z_dim)
self.lin_proposed_mean_z_to_hidden = nn.Linear(z_dim, transition_dim)
self.lin_proposed_mean_hidden_to_z = nn.Linear(transition_dim, z_dim)
self.lin_sig = nn.Linear(z_dim, z_dim)
self.lin_z_to_mu = nn.Linear(z_dim, z_dim)
# modify the default initialization of lin_z_to_mu
    # so that it starts out as the identity function
self.lin_z_to_mu.weight.data = torch.eye(z_dim)
self.lin_z_to_mu.bias.data = torch.zeros(z_dim)
# initialize the three non-linearities used in the neural network
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.softplus = nn.Softplus()
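
For context, here is a minimal sketch of how these six linear layers and three non-linearities are typically combined in the forward pass of the gated transition (the intermediate variable names `_gate` and `_proposed_mean` are illustrative, not from the original):

def forward(self, z_t_1):
    # compute the gate that interpolates between the (near-)identity map
    # and the proposed nonlinear mean
    _gate = self.relu(self.lin_gate_z_to_hidden(z_t_1))
    gate = self.sigmoid(self.lin_gate_hidden_to_z(_gate))
    # compute the proposed mean with a two-layer MLP
    _proposed_mean = self.relu(self.lin_proposed_mean_z_to_hidden(z_t_1))
    proposed_mean = self.lin_proposed_mean_hidden_to_z(_proposed_mean)
    # since lin_z_to_mu is initialized to the identity, the transition
    # starts out close to z_t = z_{t-1} when the gate is near zero
    mu = (1 - gate) * self.lin_z_to_mu(z_t_1) + gate * proposed_mean
    # pass through a softplus so the scale is strictly positive
    sigma = self.softplus(self.lin_sig(self.relu(proposed_mean)))
    return mu, sigma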