main.py source code

python

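Project: simple-linear-regression  Author: williamd4112
import logging

import tensorflow as tf  # TensorFlow 1.x graph/session API
from sklearn.model_selection import KFold


def train_cross_validation(args, sess, model, phi_xs_train, ys_train):
    """Run K-fold cross-validation and return the mean validation loss.

    `model` is the project's own regression model (not shown here);
    its fit/eval/reset methods are defined elsewhere in the repository.
    """
    kf = KFold(n_splits=args.K)

    validation_loss = 0.0

    for train_index, validation_index in kf.split(phi_xs_train):
        # Re-initialize all variables so each fold trains from scratch
        sess.run(tf.global_variables_initializer())

        # Train on the K-1 training folds, then evaluate on the held-out fold
        model.fit(sess, phi_xs_train[train_index], ys_train[train_index],
                  epoch=args.epoch, batch_size=args.batch_size)
        loss = model.eval(sess, phi_xs_train[validation_index], ys_train[validation_index])

        logging.info('Validation loss = %f', loss)
        validation_loss += loss

        model.reset(sess)

    # Average the per-fold validation losses
    return validation_loss / float(args.K)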
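The function above trains a fresh model on each of the K training splits and averages the held-out losses. The following is a minimal, TensorFlow-free sketch of the same loop, assuming a linear model fitted by closed-form least squares and mean squared error as the loss. The helpers `kfold_indices`, `fit_least_squares`, `mse` and `cross_validation_loss` are hypothetical stand-ins; the project's actual `model.fit`/`model.eval` train by mini-batch over `args.epoch` epochs and may report a different loss.

```python
import numpy as np


def kfold_indices(n_samples, k):
    """Yield (train_index, validation_index) pairs, mimicking an unshuffled
    KFold(n_splits=k): contiguous folds, the first n_samples % k folds one sample larger."""
    fold_sizes = np.full(k, n_samples // k, dtype=int)
    fold_sizes[: n_samples % k] += 1
    indices = np.arange(n_samples)
    start = 0
    for size in fold_sizes:
        stop = start + size
        validation_index = indices[start:stop]
        train_index = np.concatenate([indices[:start], indices[stop:]])
        yield train_index, validation_index
        start = stop


def fit_least_squares(phi, y):
    # Closed-form least-squares weights (hypothetical stand-in for model.fit)
    w, *_ = np.linalg.lstsq(phi, y, rcond=None)
    return w


def mse(phi, y, w):
    # Mean squared error (hypothetical stand-in for model.eval)
    return float(np.mean((phi @ w - y) ** 2))


def cross_validation_loss(phi_xs, ys, k):
    total = 0.0
    for train_index, validation_index in kfold_indices(len(ys), k):
        # A fresh fit per fold, so no fold sees its own validation data
        w = fit_least_squares(phi_xs[train_index], ys[train_index])
        total += mse(phi_xs[validation_index], ys[validation_index], w)
    return total / k


# Usage: a design matrix phi(x) = [1, x] and noiseless targets y = 2x + 1
x = np.linspace(0.0, 1.0, 10)
phi = np.column_stack([np.ones_like(x), x])
print(cross_validation_loss(phi, 2.0 * x + 1.0, k=3))  # close to zero
```

Because the targets here are exactly linear in the features, every fold recovers the true weights and the averaged loss is essentially zero; fitting the same features to a curved target such as `x ** 2` would give a strictly positive cross-validation loss.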