runCopyPostProblem.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:dizzy_layer 作者: Pastromhaug 项目源码 文件源码
def train_network(num_epochs, num_steps, state_size=4):
    """Train the model for the requested epochs, evaluating on test data after each.

    Relies on names defined elsewhere in the file: the session and graph
    objects (``sess``, ``tf``, ``x``, ``y``, ``train_step``, ``loss``,
    ``accuracy``, ``train_summary``, ``test_summary``, ``train_writer``,
    ``run_options``, ``run_metadata``, ``timeline``), the hyper-parameters
    (``num_data_points``, ``batch_size``, ``num_classes``) and the data
    helpers ``genEpochs`` / ``getTestData``.

    After training, the Chrome-trace rendering of ``run_metadata.step_stats``
    is written to ``timeline_add.json``.

    Args:
        num_epochs: Number of epochs, forwarded to ``genEpochs``.
        num_steps: Forwarded to ``genEpochs`` (presumably the sequence
            length -- TODO confirm against ``genEpochs``).
        state_size: Currently unused; kept so existing callers still work.
            NOTE(review): an earlier version built zero initial states of
            shape (batch_size, state_size) per stacked layer but never fed
            them to the graph.

    Returns:
        None. Per-epoch losses and accuracy are printed; summaries are
        written through ``train_writer``.
    """
    sess.run(tf.initialize_all_variables())

    X_test, Y_test = getTestData()

    epochs = genEpochs(num_epochs, num_data_points, num_steps, batch_size, num_classes)
    for idx, (X_epoch, Y_epoch) in enumerate(epochs):
        print("EPOCH %d" % idx)

        training_loss = 0
        num_batches = 0
        for X, Y in zip(X_epoch, Y_epoch):
            # The train_step result itself is not needed; running it applies the update.
            _, loss_, train_summary_ = sess.run(
                [train_step, loss, train_summary],
                feed_dict={x: X, y: Y},
                options=run_options, run_metadata=run_metadata)

            training_loss += loss_
            # NOTE(review): every batch summary is logged at the epoch index,
            # so all batches of one epoch share a single TensorBoard step.
            train_writer.add_summary(train_summary_, idx)
            num_batches += 1

        test_loss, test_summary_, accuracy_ = sess.run(
            [loss, test_summary, accuracy],
            feed_dict={x: X_test, y: Y_test},
            options=run_options, run_metadata=run_metadata)
        train_writer.add_summary(test_summary_, idx)

        # An epoch with no batches would otherwise raise ZeroDivisionError.
        mean_training_loss = training_loss / num_batches if num_batches else float("nan")
        print("train loss:", mean_training_loss, "test loss:", test_loss, "test accuracy:", accuracy_)

    # NOTE(review): run_metadata is shared by every traced sess.run above;
    # presumably it holds the stats of the last call (the final test
    # evaluation) -- confirm if a training-step trace is wanted instead.
    tl = timeline.Timeline(run_metadata.step_stats)
    ctf = tl.generate_chrome_trace_format()
    with open('timeline_add.json', 'w') as f:
        f.write(ctf)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号