test_loss_masking.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:keras 作者: NVIDIA 项目源码 文件源码
def test_loss_masking():
    weighted_loss = weighted_objective(objectives.get('mae'))
    shape = (3, 4, 2)
    X = np.arange(24).reshape(shape)
    Y = 2 * X

    # Normally the trailing 1 is added by standardize_weights
    weights = np.ones((3,))
    mask = np.ones((3, 4))
    mask[1, 0] = 0

    out = K.eval(weighted_loss(K.variable(X),
                               K.variable(Y),
                               K.variable(weights),
                               K.variable(mask)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号