layers.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:dcgan 作者: zsdonghao 项目源码 文件源码
def target_mask_op(data, pad_val=0):        # HangSheng: return tensor for mask,if input is tf.string
    data_shape_size = data.get_shape().ndims
    if data_shape_size == 3:
        return tf.cast(tf.reduce_any(tf.not_equal(data, pad_val), axis=2), dtype=tf.int32)
    elif data_shape_size == 2:
        return tf.cast(tf.not_equal(data, pad_val), dtype=tf.int32)
    elif data_shape_size == 1:
        raise ValueError("target_mask_op: data has wrong shape!")
    else:
        raise ValueError("target_mask_op: handling data_shape_size %s hasn't been implemented!" % (data_shape_size))


# Dynamic RNN
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号