def main(FLAG, sample_index=123):
    """Build the model, load and preprocess the dataset, then validate or train.

    Args:
        FLAG: Configuration object. Attributes read here: ``Mode``,
            ``input_dim``, ``hidden_dim``, ``output_dim``, ``learning_rate``;
            in ``"train"`` mode also ``batch_size``, ``Epoch``,
            ``save_graph`` and ``save_model``.
        sample_index: Index into the test split of the single example whose
            prediction, target and mean squared error are printed after
            training. Defaults to 123, the value previously hard-coded.

    Raises:
        ValueError: If ``FLAG.Mode`` is neither ``"validation"`` nor
            ``"train"``. Checked before any model or data work so a typo
            fails fast instead of silently doing nothing.
    """
    modes = ("validation", "train")
    if FLAG.Mode not in modes:
        raise ValueError(
            "Unknown FLAG.Mode {!r}; expected one of {}".format(FLAG.Mode, modes)
        )

    model = SimpleModel(FLAG.input_dim, FLAG.hidden_dim, FLAG.output_dim,
                        optimizer=tf.train.RMSPropOptimizer(FLAG.learning_rate))

    image, label = load_dataset()
    image, label = image_augmentation(image, label, horizon_flip=True,
                                      control_brightness=True)
    # Divide targets by 96 -- presumably the image side length in pixels, so
    # coordinates land in [0, 1]; NOTE(review): confirm against load_dataset.
    label = label / 96.
    (train_X, train_y), (valid_X, valid_y), (test_X, test_y) = split_data(image, label)

    if FLAG.Mode == "validation":
        # Random search: 20 learning rates drawn log-uniformly from [1e-6, 1e-2).
        lr_list = 10 ** np.random.uniform(-6, -2, 20)
        model.validation(train_X, train_y, valid_X, valid_y, lr_list)
    else:  # "train" -- the only other mode accepted by the check above.
        model.train(train_X, train_y, valid_X, valid_y, FLAG.batch_size,
                    FLAG.Epoch, FLAG.save_graph, FLAG.save_model)
        # Sanity check on one test-split example: prediction, target, and MSE.
        # NOTE(review): passes a single example without a batch axis, as the
        # original did -- confirm SimpleModel.predict accepts that shape.
        sample_x, sample_y = test_X[sample_index], test_y[sample_index]
        pred_y = model.predict(sample_x)
        print(pred_y)
        print(sample_y)
        print(np.mean(np.square(pred_y - sample_y)))
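# Leftover navigation labels from the web page this script was copied from:
# "评论列表" (comment list) and "文章目录" (table of contents).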