sparse_filtering.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:spykes 作者: KordingLab 项目源码 文件源码
def sparse_filtering_loss(_, y_pred):
    '''Defines the sparse filtering loss function.

    Args:
        y_true (tensor): The ground truth tensor (not used, since this is an
            unsupervised learning algorithm).
        y_pred (tensor): Tensor representing the feature vector at a
            particular layer.

    Returns:
        scalar tensor: The sparse filtering loss.
    '''
    y = tf.reshape(y_pred, tf.stack([-1, tf.reduce_prod(y_pred.shape[1:])]))
    l2_normed = tf.nn.l2_normalize(y, dim=1)
    l1_norm = tf.norm(l2_normed, ord=1, axis=1)
    return tf.reduce_sum(l1_norm)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号