Deopen_regression.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Deopen 作者: kimmo1019 项目源码 文件源码
def model_train(X_train, y_train, learning_rate=1e-4, epochs=50):
    """Fine-tune a regression network starting from saved pre-training weights.

    Builds the architecture with ``create_network()`` and wraps it in a
    ``NeuralNet`` configured for regression with a squared-error loss and the
    ``adam`` update rule. It then loads the parameter set selected by
    ``params[val_loss.argmin()]`` and continues training on the given data.

    NOTE(review): ``params`` and ``val_loss`` are not arguments. They are
    resolved outside this function, presumably as module-level globals that
    hold candidate parameter files and their validation losses. Confirm this
    and consider passing them in explicitly.

    Args:
        X_train: Training inputs, forwarded unchanged to ``fit``.
        y_train: Regression targets, forwarded unchanged to ``fit``.
        learning_rate: Initial learning rate for the update rule. It is
            wrapped in a float32 Theano shared variable.
        epochs: Value passed as ``max_epochs``. An ``EarlyStopping(patience=5)``
            epoch callback is also attached and presumably may end training
            earlier.

    Returns:
        The ``NeuralNet`` instance after ``fit`` has run.
    """
    layers = create_network()
    shared_lr = theano.shared(np.float32(learning_rate))

    settings = dict(
        max_epochs=epochs,
        update=adam,
        update_learning_rate=shared_lr,
        # eval_size=0.1: presumably a 10% validation hold-out -- confirm.
        train_split=TrainSplit(eval_size=0.1),
        batch_iterator_train=BatchIterator(batch_size=32),
        batch_iterator_test=BatchIterator(batch_size=64),
        regression=True,
        objective_loss_function=squared_error,
        # NOTE(review): an ``on_training_started`` LoadBestParam hook was
        # previously commented out here; weights are loaded manually below.
        on_epoch_finished=[EarlyStopping(patience=5)],
        verbose=1,
    )
    model = NeuralNet(layers, **settings)

    print('loading pre-training weights...')
    best_idx = val_loss.argmin()
    model.load_params_from(params[best_idx])

    print('continue to train...')
    model.fit(X_train, y_train)
    print('training finished')
    return model

#model testing
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号