def forward(self, input_, time):
    """Batch-normalize ``input_`` with the running statistics of step ``time``.

    Running statistics are stored per timestep in the attributes
    ``running_mean_<t>`` / ``running_var_<t>``. Steps at or beyond
    ``self.max_length`` reuse the statistics of the last stored step
    (``self.max_length - 1``).

    Args:
        input_: Tensor to normalize; validated by ``self._check_input_dim``
            before use (the expected layout is whatever that check
            enforces -- not visible here).
        time: Non-negative integer timestep. Any value accepted by
            ``operator.index`` works, e.g. a Python ``int``, a NumPy
            integer, or a single-element integer tensor.

    Returns:
        The result of ``functional.batch_norm`` applied with the selected
        running statistics and ``self.weight`` / ``self.bias``.

    Raises:
        TypeError: If ``time`` is not an integer-like value.
        ValueError: If ``time`` is negative.
    """
    import operator

    self._check_input_dim(input_)
    # Normalize to a plain int so the attribute name built below is
    # e.g. 'running_mean_3' rather than 'running_mean_tensor(3)'.
    time = operator.index(time)
    if time < 0:
        raise ValueError('time must be non-negative, got {}'.format(time))
    # Timesteps past the stored range reuse the last step's statistics.
    time = min(time, self.max_length - 1)
    running_mean = getattr(self, 'running_mean_{}'.format(time))
    running_var = getattr(self, 'running_var_{}'.format(time))
    # When self.training is true, batch_norm normalizes with batch
    # statistics and updates the selected running buffers in place.
    return functional.batch_norm(
        input=input_, running_mean=running_mean, running_var=running_var,
        weight=self.weight, bias=self.bias, training=self.training,
        momentum=self.momentum, eps=self.eps)
评论列表
文章目录