controller.py source code

python

Project: Facial_KeyPoints_Detection    Author: wadhwasahil
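import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.Session)

import data_helpers  # project-local module: mini-batch iterator
import linear_nn     # project-local module: linear model definition


def linear_NN(X, y):
    """Train the linear model on (X, y) with Nesterov momentum and plot the training loss."""
    graph = tf.Graph()
    with graph.as_default():
        nn = linear_nn.nn_linear(X, y)
        # Step counter; minimize() increments it once per training step.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        train_op = tf.train.MomentumOptimizer(
            learning_rate=0.001,
            momentum=0.9,
            use_nesterov=True,
        ).minimize(nn.loss, global_step=global_step)
        with tf.Session(graph=graph) as session:
            train_loss_history = []
            session.run(tf.global_variables_initializer())
            # num_epochs is not defined in this function; it is assumed to be a
            # module-level setting elsewhere in controller.py.
            # list(...) is needed on Python 3, where zip() returns a one-shot iterator.
            batches = data_helpers.batch_iter(list(zip(X, y)), batch_size=64, num_epochs=num_epochs, shuffle=True)
            for batch in batches:
                X_train, y_train = zip(*batch)
                feed_dict = {nn.input_x: np.asarray(X_train), nn.input_y: np.asarray(y_train)}
                _, step, loss, predictions = session.run([train_op, global_step, nn.loss, nn.predictions], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{}: step {}, loss {:g}".format(time_str, step, loss))
                train_loss_history.append(loss)
            # One x value per recorded loss. Deriving it from the history length keeps
            # the arrays the same size; np.arange(step) can be off by one because the
            # fetched global_step may or may not include the last increment.
            x_axis = np.arange(1, len(train_loss_history) + 1)
            plt.plot(x_axis, train_loss_history, "b-", linewidth=2, label="train")
            plt.grid()
            plt.legend()
            plt.xlabel("step")
            plt.ylabel("loss")
            plt.show()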
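The training loop depends on `data_helpers.batch_iter`, a project module that this page does not show. Below is a hypothetical minimal sketch, assuming it yields shuffled mini-batches of `(x, y)` pairs for a fixed number of epochs. The signature is inferred only from the call in `linear_NN`, so the real helper may differ. The sketch also converts its input to a list, because in Python 3 `zip(X, y)` is a one-shot iterator that has no `len()` and cannot be indexed.

```python
import numpy as np


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """Yield mini-batches (lists of items) from `data` for `num_epochs` passes.

    Hypothetical stand-in for the project's data_helpers.batch_iter.
    """
    # Materialize the input so it can be measured and indexed.
    data = list(data)
    data_size = len(data)
    # Batches per epoch, counting a final partial batch.
    num_batches = (data_size + batch_size - 1) // batch_size
    for _ in range(num_epochs):
        # Draw a fresh random order each epoch when shuffling.
        order = np.random.permutation(data_size) if shuffle else np.arange(data_size)
        for b in range(num_batches):
            idx = order[b * batch_size:(b + 1) * batch_size]
            yield [data[i] for i in idx]


# Usage: 5 samples in batches of 2 give 3 batches per epoch (2, 2, 1).
X = np.arange(10).reshape(5, 2)
y = np.arange(5)
batches = list(batch_iter(zip(X, y), batch_size=2, num_epochs=3, shuffle=False))
print(len(batches))  # 9
X_train, y_train = zip(*batches[0])  # same unpacking as in linear_NN
```

Each batch is a list of `(x, y)` tuples, so the `zip(*batch)` unpacking in `linear_NN` splits it back into inputs and targets.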