def forward(self, input):
    # Run the model across the configured GPUs when the input is a CUDA tensor
    if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
        output = nn.parallel.data_parallel(self.model, input, self.gpu_ids)
    else:
        output = self.model(input)
    if self.learn_residual:
        # Predict a residual: add the input back, then clamp to the [-1, 1] image range
        output = input + output
        output = torch.clamp(output, min=-1, max=1)
    return output
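For intuition, learn_residual makes the network predict a correction that is added back to the input, and the clamp keeps the sum inside the [-1, 1] range used by tanh-normalized images. A minimal illustration of just that arithmetic (the tensors here are made-up values, not outputs of the model):

import torch

x = torch.tensor([-0.9, 0.2, 0.8])       # "input" image values
delta = torch.tensor([-0.5, 0.1, 0.6])   # residual predicted by the model
out = torch.clamp(x + delta, min=-1, max=1)
print(out)   # tensor([-1.0000, 0.3000, 1.0000]) -- sums saturate at the bounds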
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
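The diagram above describes a U-Net style block: the input is carried over on an identity path while a downsample/upsample pipeline processes it, and the two paths are merged at the output. A minimal sketch of such a block, assuming the merge is channel-wise concatenation as in pix2pix-style U-Nets (the class name SkipBlock and the layer choices are illustrative, not the repository's exact definition):

import torch
import torch.nn as nn

class SkipBlock(nn.Module):
    # Illustrative skip-connection submodule: identity path concatenated
    # with a downsampling -> submodule -> upsampling path.
    def __init__(self, outer_nc, inner_nc, submodule=None):
        super().__init__()
        # A nested SkipBlock doubles its channel count via the concat in
        # forward(), so the upsampling layer must expect inner_nc * 2 inputs.
        up_in = inner_nc * 2 if submodule is not None else inner_nc
        down = [nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, True)]
        up = [nn.ConvTranspose2d(up_in, outer_nc, kernel_size=4, stride=2, padding=1),
              nn.ReLU(True)]
        mid = [submodule] if submodule is not None else []
        self.model = nn.Sequential(*(down + mid + up))

    def forward(self, x):
        # Skip connection: keep the input (identity path) alongside the result
        return torch.cat([x, self.model(x)], 1)

# Nesting two blocks: 64x64 -> 32x32 -> 16x16 -> 32x32 -> 64x64
inner = SkipBlock(outer_nc=64, inner_nc=128)
outer = SkipBlock(outer_nc=3, inner_nc=64, submodule=inner)
y = outer(torch.randn(1, 3, 64, 64))
print(y.shape)   # torch.Size([1, 6, 64, 64]): 3 identity + 3 processed channels

In the actual pix2pix-style UnetSkipConnectionBlock, the outermost block returns the processed path without concatenation so the final output keeps the image channel count; the sketch omits that case and the norm/dropout options for brevity.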