19-mnist-profiling.py source code

python

Project: tensorflow-talk-debugging  Author: wookayin
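import tensorflow as tf
from datetime import datetime
from tensorflow.python.client import timeline

# `mnist`, `x`, `y`, `pred`, `loss` and `train_op` are used as globals here,
# so they must be defined elsewhere in the script (data set and model graph).
# Lines marked (*) are the profiling-related ones.

def train(session):
    batch_size = 200
    session.run(tf.global_variables_initializer())
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)   # (*)
    run_metadata = tf.RunMetadata()

    # Training cycle
    for epoch in range(10):
        epoch_loss = 0.0
        # Floor division keeps this an int, which range() requires on Python 3.
        batch_steps = mnist.train.num_examples // batch_size
        for step in range(batch_steps):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, c = session.run(
                [train_op, loss],
                feed_dict={x: batch_x, y: batch_y},
                options=run_options, run_metadata=run_metadata          # (*)
            )
            epoch_loss += c / batch_steps   # running mean of the per-batch loss
        print("[%s] Epoch %02d, Loss = %.6f" % (datetime.now(), epoch, epoch_loss))

    # Dump profiling data (*)
    # run_metadata is refilled by every session.run call, so the trace
    # written here covers only the last training step.
    prof_timeline = timeline.Timeline(run_metadata.step_stats)
    prof_ctf = prof_timeline.generate_chrome_trace_format()
    with open('./prof_ctf.json', 'w') as fp:
        fp.write(prof_ctf)
    print('Dumped to prof_ctf.json')

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Run through the session passed in; accuracy.eval() would need a default session.
    print("Accuracy:", session.run(
        accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))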
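Besides loading `prof_ctf.json` in `chrome://tracing`, the dump can be summarized with a few lines of standard-library Python. The sketch below is a minimal version that assumes the file follows the Chrome Trace Event Format: a top-level `traceEvents` list in which complete events carry `"ph": "X"` and a `dur` in microseconds. `summarize_events` and `load_and_summarize` are hypothetical helpers, not part of TensorFlow. Events are grouped by their `name`, which for TensorFlow's timeline is typically the op type.

```python
import json
from collections import defaultdict


def summarize_events(trace, top_k=10):
    """Return (name, total_microseconds) pairs, largest total first.

    `trace` is a parsed Chrome trace (a dict with a "traceEvents" list).
    Only complete events ("ph": "X") carry a duration, so metadata,
    flow and counter events are skipped.
    """
    totals = defaultdict(float)
    for event in trace.get("traceEvents", []):
        if event.get("ph") == "X":
            totals[event["name"]] += event.get("dur", 0)
    # Sort by total time descending, then by name so ties are deterministic.
    ranked = sorted(totals.items(), key=lambda item: (-item[1], item[0]))
    return ranked[:top_k]


def load_and_summarize(path, top_k=10):
    """Read a trace file such as ./prof_ctf.json and summarize it."""
    with open(path) as fp:
        return summarize_events(json.load(fp), top_k)
```

For example, `load_and_summarize("./prof_ctf.json")` lists the ten names with the largest accumulated time. Because `run_metadata` is refilled on every `session.run` call, these totals describe a single training step, not the whole run.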
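The accuracy op at the end compares the arg-max class of each prediction row with that of the one-hot label row, casts the matches to 0.0/1.0, and averages them. The following plain-Python sketch reproduces that arithmetic, using lists in place of tensors; `argmax` and `accuracy` here are hypothetical helpers, not TensorFlow functions.

```python
def argmax(row):
    """Index of the largest entry (the first one if several are equal)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(pred_rows, label_rows):
    """Mean of argmax(pred) == argmax(label) over rows, as in the TF graph."""
    if len(pred_rows) != len(label_rows) or not pred_rows:
        raise ValueError("need the same, non-zero number of predictions and labels")
    matches = [argmax(p) == argmax(t) for p, t in zip(pred_rows, label_rows)]
    # Casting the booleans to 0.0/1.0 and averaging gives the fraction correct.
    return sum(float(m) for m in matches) / len(matches)
```

Only the arg-max is compared, so `pred` may hold raw logits or softmax probabilities and the result is the same. This sketch breaks ties toward the first index.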