runAdditionProblem.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:dizzy_layer 作者: Pastromhaug 项目源码 文件源码
def train_network(num_epochs, num_steps, state_size=4):
    """Train the model for the requested epochs, printing train/test loss.

    This function is not self-contained: it reads module-level objects that
    are defined elsewhere in the file:

    - ``sess`` and ``tf``.
    - The graph ops/tensors ``x``, ``y``, ``train_step``, ``loss``,
      ``summary``, ``prediction`` and ``test_loss_summary``.
    - ``run_options`` / ``run_metadata`` and ``train_writer``.
    - The data helpers ``genEpochs`` / ``getTestData`` and the settings
      ``num_data_points`` / ``batch_size``.

    Each epoch runs ``train_step`` on every batch yielded by ``genEpochs``.
    It then evaluates ``loss`` (without training) on every batch of the
    fixed test set from ``getTestData``. Summaries go to ``train_writer``.

    Args:
        num_epochs: number of epochs, forwarded to ``genEpochs``.
        num_steps: forwarded to ``genEpochs``. Presumably the sequence
            length; confirm against ``genEpochs``.
        state_size: not used by this function; kept for call compatibility.

    Returns:
        A list holding the mean training loss of each epoch, in epoch order.
        An epoch with no batches contributes ``float("nan")`` instead of
        raising ``ZeroDivisionError``.
    """
    sess.run(tf.initialize_all_variables())
    training_losses = []

    # The test set is fetched once and reused for every epoch.
    test_X_epoch, test_Y_epoch = getTestData()

    epochs = genEpochs(num_epochs, num_data_points, num_steps, batch_size)
    for idx, (X_epoch, Y_epoch) in enumerate(epochs):
        print("EPOCH %d" % idx)

        training_loss = 0
        num_batches = 0
        for X, Y in zip(X_epoch, Y_epoch):
            # train_step is fetched for its parameter-update side effect.
            # prediction is still fetched, as before, but its value is unused.
            _, loss_, summary_, _ = sess.run(
                [train_step, loss, summary, prediction],
                feed_dict={x: X, y: Y},
                options=run_options, run_metadata=run_metadata)
            training_loss += loss_
            # NOTE(review): every batch summary of an epoch is logged at the
            # same step ``idx``; a running batch counter may have been intended.
            train_writer.add_summary(summary_, idx)
            num_batches += 1

        test_loss = 0
        test_num_batches = 0
        for X_test, Y_test in zip(test_X_epoch, test_Y_epoch):
            # Evaluation only: train_step is deliberately not fetched here.
            test_loss_, test_loss_summary_ = sess.run(
                [loss, test_loss_summary],
                feed_dict={x: X_test, y: Y_test},
                options=run_options, run_metadata=run_metadata)
            test_loss += test_loss_
            train_writer.add_summary(test_loss_summary_, idx)
            test_num_batches += 1

        # Guard the averages: an empty epoch or test set previously raised
        # ZeroDivisionError here.
        test_loss = test_loss / test_num_batches if test_num_batches else float("nan")
        training_loss = training_loss / num_batches if num_batches else float("nan")
        print("train loss:", training_loss, "test loss", test_loss)
        training_losses.append(training_loss)

    # run_metadata can be exported as a Chrome trace via
    # timeline.Timeline(run_metadata.step_stats).generate_chrome_trace_format().
    return training_losses
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号