nets.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:e2e-model-learning 作者: locuslab 项目源码 文件源码
def run_rmse_net(model, variables, X_train, Y_train):
    opt = optim.Adam(model.parameters(), lr=1e-3)

    for i in range(1000):
        opt.zero_grad()
        model.train()
        train_loss = nn.MSELoss()(
            model(variables['X_train_'])[0], variables['Y_train_'])
        train_loss.backward()
        opt.step()

        model.eval()
        test_loss = nn.MSELoss()(
            model(variables['X_test_'])[0], variables['Y_test_'])

        print(i, train_loss.data[0], test_loss.data[0])

    model.eval()
    model.set_sig(variables['X_train_'], variables['Y_train_'])

    return model


# TODO: minibatching
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号