def forward_one_step(self, x, y, test=False, apply_f=True):
    # f: the activation function selected by name; `activations` is assumed to be
    # a dict such as {"relu": F.relu, "elu": F.elu, ...} with F = chainer.functions.
    f = activations[self.activation_function]

    # Merge the two inputs x and y into a single hidden representation.
    if self.apply_batchnorm_to_input:
        if self.batchnorm_before_activation:
            merged_input = f(self.batchnorm_merge(self.layer_merge_x(x) + self.layer_merge_y(y), test=test))
        else:
            merged_input = f(self.layer_merge_x(self.batchnorm_merge(x, test=test)) + self.layer_merge_y(y))
    else:
        merged_input = f(self.layer_merge_x(x) + self.layer_merge_y(y))

    chain = [merged_input]

    # Hidden layers: linear -> (optional) batch normalization -> activation -> (optional) dropout.
    # When batchnorm_before_activation is False, batch normalization is applied to the
    # layer input instead of the pre-activation output.
    for i in range(self.n_layers):
        u = chain[-1]
        if self.batchnorm_before_activation:
            u = getattr(self, "layer_%d" % i)(u)
        if self.apply_batchnorm:
            u = getattr(self, "batchnorm_%d" % i)(u, test=test)
        if not self.batchnorm_before_activation:
            u = getattr(self, "layer_%d" % i)(u)
        output = f(u)
        if self.apply_dropout:
            output = F.dropout(output, train=not test)
        chain.append(output)

    # Output layers: mean and log-variance (ln_var = log(sigma^2)) of a diagonal Gaussian.
    u = chain[-1]
    mean = self.layer_output_mean(u)
    ln_var = self.layer_output_var(u)
    return mean, ln_var
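Since the method parameterizes a diagonal Gaussian by its mean and log-variance, a typical caller draws a latent sample with the reparameterization trick and, for a variational objective, evaluates the KL term against a standard normal prior. The following is a minimal sketch using the old (v1-style) Chainer API; `encoder`, the input dimensions, and the batch size are hypothetical placeholders for whatever the surrounding model defines.

import numpy as np
import chainer.functions as F
from chainer import Variable

# Hypothetical inputs: a minibatch of 16 examples with 784-dim x and 10-dim y.
x = Variable(np.random.rand(16, 784).astype(np.float32))
y = Variable(np.random.rand(16, 10).astype(np.float32))

# `encoder` is assumed to be an instance of the class that defines forward_one_step.
mean, ln_var = encoder.forward_one_step(x, y, test=False)

# Reparameterization trick: z = mean + exp(ln_var / 2) * eps, with eps ~ N(0, I).
z = F.gaussian(mean, ln_var)

# KL divergence against the standard normal prior, averaged over the batch.
kl = F.gaussian_kl_divergence(mean, ln_var) / mean.data.shape[0]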