train.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainer-dfi 作者: dsanno 项目源码 文件源码
def update(net, optimizer, link, target_layers, tv_weight=0.001):
    layers = feature(net, link.x)
    total_loss = 0
    losses = []
    for layer, target in zip(layers, target_layers):
        loss = F.mean_squared_error(layer, target)
        losses.append(float(loss.data))
        total_loss += loss
    tv_loss = tv_weight * total_variation(link.x)
    losses.append(float(tv_loss.data))
    total_loss += tv_loss
    link.cleargrads()
    total_loss.backward()
    optimizer.update()
    return losses
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号