def train_cross_validation(args, sess, model, phi_xs_train, ys_train):
    """Estimate the model's validation loss with K-fold cross-validation.

    ``KFold`` (without shuffling) splits the training data into ``args.K``
    folds. For each fold, the TensorFlow global variables are initialized.
    ``model`` is then fit on the fold's training part and evaluated on its
    held-out part, and finally ``model.reset(sess)`` is called.

    Args:
        args: Configuration object. Reads ``K`` (number of folds), ``epoch``
            and ``batch_size``.
        sess: TensorFlow session. It is used to run the variable initializer
            and is passed to ``model.fit`` / ``model.eval`` / ``model.reset``.
        model: Object exposing ``fit``, ``eval`` and ``reset`` methods that
            take ``sess`` as their first argument.
        phi_xs_train: Training features. They must support indexing with
            the integer index arrays yielded by ``KFold.split`` (presumably
            a NumPy array; TODO confirm against callers).
        ys_train: Training targets, indexed the same way as
            ``phi_xs_train``.

    Returns:
        The per-fold losses returned by ``model.eval``, summed and divided by
        ``args.K``.
    """
    kf = KFold(n_splits=args.K)
    validation_loss = 0
    for train_index, validation_index in kf.split(phi_xs_train):
        # Re-initialize variables at the start of every fold (presumably so
        # weights learned on one fold do not carry over into the next).
        sess.run(tf.global_variables_initializer())
        model.fit(
            sess,
            phi_xs_train[train_index],
            ys_train[train_index],
            epoch=args.epoch,
            batch_size=args.batch_size,
        )
        loss = model.eval(
            sess, phi_xs_train[validation_index], ys_train[validation_index]
        )
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logging.info('Validation loss = %f', loss)
        validation_loss += loss
        model.reset(sess)
    # KFold yields exactly n_splits folds, so this is the mean fold loss.
    return validation_loss / float(args.K)
# Comment list (评论列表): web-page residue, not code
# Table of contents (文章目录): web-page residue, not code