def forward(self, input):
    """Max-pool a single input, or dispatch a sequence of inputs.

    Args:
        input: Either a ``Variable`` or a ``tuple``/``list``. The elements
            of a sequence are not inspected here. (The name shadows the
            builtin ``input`` but is kept so keyword callers still work.)

    Returns:
        For a ``Variable``: the result of ``F.max_pool2d`` called with this
        module's ``kernel_size``, ``stride``, ``padding``, ``dilation``,
        ``ceil_mode`` and ``return_indices``.
        For a ``tuple`` or ``list``: whatever ``my_data_parallel(self,
        input)`` returns.

    Raises:
        RuntimeError: If ``input`` is none of the accepted types.
    """
    if isinstance(input, Variable):
        # Arguments are passed positionally, in the same order as before.
        return F.max_pool2d(
            input,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )

    # NOTE(review): my_data_parallel is defined outside this block;
    # presumably it applies this module to each element of the sequence
    # -- confirm against its definition.
    if isinstance(input, (tuple, list)):
        return my_data_parallel(self, input)

    raise RuntimeError('unknown input type')
评论列表
文章目录