mask_pooling.py source file

python

Project: yoctol-keras-layer-zoo  Author: Yoctol
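# Assumed imports (not shown in the excerpt): the TensorFlow backend of Keras.
import tensorflow as tf
from keras import backend as K


# Method of a mask-aware pooling wrapper; `self.layer` is the wrapped pooling layer.
def call(self, inputs, mask=None):
        inputs_tensor = inputs
        # Add a trailing axis so the mask lines up with the channel axis.
        # Note: `mask` must be a boolean tensor; a None mask is not handled here.
        mask_inputs = K.expand_dims(mask)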

        inputs_shape = K.int_shape(inputs)
        channel_axis = len(inputs_shape) - 1

        if self.pool_mode == 'max':
            # Shift masked positions by a large negative value so they never win the max.
            mask_inv = tf.logical_not(mask_inputs)
            negative_mask = K.cast(mask_inv, K.floatx()) * -1e20
            negative_mask = K.repeat_elements(
                negative_mask,
                inputs_shape[channel_axis],
                channel_axis
            )
            inputs_tensor = inputs + negative_mask

        output = self.layer._pooling_function(
            inputs_tensor,
            self.layer.pool_size,
            self.layer.strides,
            self.layer.padding,
            self.layer.data_format,
        )
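        # Pool the mask with the same settings as the inputs, so each output
        # position knows whether its window contained any valid input.
        mask_inputs = K.cast(mask_inputs, K.floatx())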

        mask_output = self.layer._pooling_function(
            mask_inputs,
            self.layer.pool_size,
            self.layer.strides,
            self.layer.padding,
            self.layer.data_format,
        )
        mask_output = K.repeat_elements(
            mask_output,
            inputs_shape[channel_axis],
            channel_axis
        )
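        # Scale each output by its pooled mask value; windows with no valid
        # input are multiplied by zero.
        return output * mask_output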
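
The 'max' branch above can be traced on a plain Python list. This is a minimal sketch of the same idea, not the project's code: `masked_max_pool_1d` and `NEG_FILL` are hypothetical names, there is a single channel, and the windows are assumed to be non-overlapping with no padding (stride equal to the pool size).

```python
NEG_FILL = -1e20  # same large negative offset as in the snippet above


def masked_max_pool_1d(values, mask, pool_size):
    """Max-pool a 1-D sequence, ignoring timesteps where mask is False.

    Assumes non-overlapping windows and no padding (illustrative only).
    """
    # Step 1: push masked timesteps far below any real value.
    shifted = [v if m else v + NEG_FILL for v, m in zip(values, mask)]

    pooled = []
    for start in range(0, len(values) - pool_size + 1, pool_size):
        window = slice(start, start + pool_size)
        # Step 2: ordinary max pooling on the shifted values.
        best = max(shifted[window])
        # Step 3: max-pool the mask as well; 1.0 if any timestep is valid.
        window_valid = max(1.0 if m else 0.0 for m in mask[window])
        # Step 4: multiply, so fully masked windows collapse to zero.
        pooled.append(best * window_valid)
    return pooled


values = [1.0, 5.0, 2.0, 9.0, 3.0, 4.0]
mask = [True, False, True, True, False, False]
print(masked_max_pool_1d(values, mask, pool_size=2))
```

In the first window the masked 5.0 is ignored, so the result is 1.0. The second window is fully valid and gives 9.0. The third window is fully masked and pools to zero (printed as -0.0, because a large negative value is multiplied by 0.0).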