def call(self, inputs, mask=None):
    """Apply the wrapped pooling layer to ``inputs`` while respecting ``mask``.

    When a mask is given, masked positions are pushed to a very large
    negative value before max pooling so they cannot win a window, and the
    pooled result is multiplied by the pooled mask. When no mask is given,
    the wrapped layer's pooling function is applied to ``inputs`` unchanged.

    Args:
        inputs: Tensor to pool. Assumed to carry its feature/channel
            dimension on the last axis -- TODO confirm against callers.
        mask: Optional mask tensor with one fewer axis than ``inputs``
            (a trailing axis is appended here). Presumably boolean, since
            ``tf.logical_not`` is applied to it in ``'max'`` mode.

    Returns:
        The pooled tensor multiplied element-wise by the pooled mask, or
        the plain pooled tensor when ``mask`` is ``None``.
    """
    # The same four positional arguments are used for every pooling call.
    pool_args = (
        self.layer.pool_size,
        self.layer.strides,
        self.layer.padding,
        self.layer.data_format,
    )

    if mask is None:
        # No mask to honour: behave exactly like the wrapped pooling layer
        # instead of failing inside K.expand_dims(None).
        return self.layer._pooling_function(inputs, *pool_args)

    # The trailing singleton axis lets the mask broadcast across the last
    # axis of ``inputs``; no explicit tiling is needed, so a statically
    # unknown channel dimension is fine as well.
    mask_inputs = K.expand_dims(mask)

    inputs_tensor = inputs
    if self.pool_mode == 'max':
        # Add -1e20 wherever the mask is False so that those positions are
        # never selected as the maximum of a pooling window.
        mask_inv = tf.logical_not(mask_inputs)
        negative_mask = K.cast(mask_inv, K.floatx()) * -1e20
        inputs_tensor = inputs + negative_mask

    output = self.layer._pooling_function(inputs_tensor, *pool_args)

    # Pool the mask with the same function and settings so that it lines up
    # with the pooled output.
    mask_output = self.layer._pooling_function(
        K.cast(mask_inputs, K.floatx()),
        *pool_args
    )

    # mask_output keeps its singleton last axis and broadcasts over channels.
    return output * mask_output
评论列表
文章目录