train.py source code

python

Project: real_time_face_detection    Author: Snowapril
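import numpy as np
import tensorflow as tf  # TF 1.x API (tf.train.RMSPropOptimizer)

# SimpleModel, load_dataset, image_augmentation and split_data are defined
# elsewhere in the real_time_face_detection project (not shown here).


def main(FLAG):
    model = SimpleModel(FLAG.input_dim, FLAG.hidden_dim, FLAG.output_dim,
                        optimizer=tf.train.RMSPropOptimizer(FLAG.learning_rate))

    image, label = load_dataset()
    image, label = image_augmentation(image, label, horizon_flip=True, control_brightness=True)
    # Scale keypoint coordinates down by 96 (presumably the image side length in pixels).
    label = label / 96.
    (train_X, train_y), (valid_X, valid_y), (test_X, test_y) = split_data(image, label)

    if FLAG.Mode == "validation":
        # Random search: 20 learning rates sampled log-uniformly from [1e-6, 1e-2).
        lr_list = 10 ** np.random.uniform(-6, -2, 20)
        model.validation(train_X, train_y, valid_X, valid_y, lr_list)
    elif FLAG.Mode == "train":
        model.train(train_X, train_y, valid_X, valid_y, FLAG.batch_size, FLAG.Epoch,
                    FLAG.save_graph, FLAG.save_model)

        # Sanity check on one test sample (index 123): prediction, target, and their MSE.
        pred_Y = model.predict(test_X[123])
        print(pred_Y)
        print(test_y[123])
        print(np.mean(np.square(pred_Y - test_y[123])))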
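The body of image_augmentation(..., horizon_flip=True, ...) is not shown on this page. As a rough sketch of what a horizontal flip has to do for keypoint regression, the hypothetical helper below mirrors each image, mirrors the x coordinates, and optionally swaps left/right keypoint pairs. It assumes images shaped (N, H, W) or (N, H, W, C) and labels stored as interleaved pixel coordinates x0, y0, x1, y1, ... (both are assumptions, not confirmed by the source); brightness control is not covered.

```python
import numpy as np


def hflip_with_keypoints(images, labels, flip_pairs=()):
    """Hypothetical helper: mirror images left-right and fix keypoint labels.

    Assumes images are (N, H, W) or (N, H, W, C) and labels are (N, 2K) in
    pixel units, laid out as x0, y0, x1, y1, ...
    flip_pairs lists (i, j) keypoint indices to swap (e.g. left/right eye),
    because after mirroring the left eye sits where the right eye was.
    """
    width = images.shape[2]
    # Reverse the width axis; the ellipsis also matches a missing channel axis.
    flipped_images = images[:, :, ::-1, ...]
    flipped_labels = labels.astype(np.float64, copy=True)
    # Mirror x coordinates about the vertical centre line (0-based pixel columns).
    flipped_labels[:, 0::2] = (width - 1) - flipped_labels[:, 0::2]
    # Swap paired keypoints so each label column keeps its meaning.
    for i, j in flip_pairs:
        ci, cj = [2 * i, 2 * i + 1], [2 * j, 2 * j + 1]
        tmp = flipped_labels[:, ci].copy()
        flipped_labels[:, ci] = flipped_labels[:, cj]
        flipped_labels[:, cj] = tmp
    # Return originals plus mirrored copies (doubling is an assumption here).
    return (np.concatenate([images, flipped_images]),
            np.concatenate([labels, flipped_labels]))
```

Whether the project's helper appends mirrored copies like this or replaces the originals is not visible here; the concatenation is a choice of this sketch.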
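split_data is not shown on this page either. A minimal stand-in, assuming a shuffled split that holds out 10% for validation and 10% for test (fractions chosen only for illustration), could look like this; shuffle_split is a hypothetical name, not the project's function.

```python
import numpy as np


def shuffle_split(X, y, valid_frac=0.1, test_frac=0.1, seed=0):
    """Hypothetical helper: shuffle (X, y) together, then cut into
    (train, valid, test) tuples of (inputs, targets)."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of samples")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))  # one shared permutation keeps pairs aligned
    n_test = int(len(X) * test_frac)
    n_valid = int(len(X) * valid_frac)
    test_idx = idx[:n_test]
    valid_idx = idx[n_test:n_test + n_valid]
    train_idx = idx[n_test + n_valid:]
    return ((X[train_idx], y[train_idx]),
            (X[valid_idx], y[valid_idx]),
            (X[test_idx], y[test_idx]))
```

Note that main() splits after augmenting. If the augmentation appends mirrored copies, an image and its mirror can land in different splits, which leaks test information into training; splitting first and augmenting only the training set avoids that.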
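In validation mode, main() draws 20 learning rates as 10 ** uniform(-6, -2), i.e. log-uniformly between 1e-6 and 1e-2, and hands them to Model.validation, whose body is not shown. The sketch below illustrates the same random-search idea; train_fn is a hypothetical callable that trains briefly at a given rate and returns a validation loss, and the quadratic "loss" in the demo is a toy stand-in, not real training.

```python
import numpy as np


def sample_learning_rates(n, low_exp=-6.0, high_exp=-2.0, seed=None):
    """Draw n learning rates log-uniformly from [10**low_exp, 10**high_exp)."""
    rng = np.random.default_rng(seed)
    # Uniform in the exponent spreads samples evenly over orders of magnitude.
    return 10.0 ** rng.uniform(low_exp, high_exp, size=n)


def random_search(train_fn, lr_list):
    """Return (best_lr, best_loss), where train_fn(lr) gives a validation loss."""
    results = [(lr, train_fn(lr)) for lr in lr_list]
    return min(results, key=lambda pair: pair[1])


if __name__ == "__main__":
    lrs = sample_learning_rates(20, seed=0)
    # Toy stand-in for "train briefly, report validation loss":
    # a bowl centred on lr = 1e-4 in log space.
    toy_loss = lambda lr: (np.log10(lr) + 4.0) ** 2
    best_lr, best_loss = random_search(toy_loss, lrs)
    print(f"best lr = {best_lr:.2e}, toy loss = {best_loss:.4f}")
```

Sampling the exponent rather than the rate itself matters: drawing uniformly from [1e-6, 1e-2] directly would put about 90% of the trials above 1e-3.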