linear_regression.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:dockerfiles 作者: floydhub 项目源码 文件源码
def linear_train(train_data, train_target, n_epochs=200):
    for _ in range(n_epochs):
        # Get the result of the forward pass.
        output = linear_forward(train_data)

        # Calculate the loss between the training data and target data.
        loss = F.mean_squared_error(train_target, output)

        # Zero all gradients before updating them.
        linear_function.zerograds()

        # Calculate and update all gradients.
        loss.backward()

        # Use the optmizer to move all parameters of the network
        # to values which will reduce the loss.
        optimizer.update()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号