def train(session, num_epochs=10, batch_size=200, trace_path='prof_ctf.json'):
    """Train the model, dump a Chrome-trace profile, then report test accuracy.

    Uses module-level names defined elsewhere in this file: ``tf``, ``mnist``
    (presumably the TF tutorial MNIST dataset -- confirm), the feed
    placeholders ``x``/``y``, the ``pred``/``loss``/``train_op`` graph nodes
    and ``datetime``.

    Args:
        session: the tf.Session every op here is run in (it does not need to
            be installed as the default session).
        num_epochs: number of training epochs (default 10, as before).
        batch_size: examples fetched per training step (default 200).
        trace_path: file the Chrome trace JSON of the profiled step is
            written to (default 'prof_ctf.json' in the working directory).
    """
    # (*) Imported explicitly: the ``tf.python`` attribute path is not public
    # API and only resolves if this submodule happens to be loaded already.
    from tensorflow.python.client import timeline

    session.run(tf.global_variables_initializer())

    # (*) FULL_TRACE slows down every step it is attached to, and session.run
    # re-populates run_metadata on each call, so tracing every step only ever
    # kept the stats of the last one. Trace just that final step.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()

    # Steps per epoch; floor division keeps this an int for range() on both
    # Python 2 and Python 3.
    batch_steps = mnist.train.num_examples // batch_size

    # Training cycle
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for step in range(batch_steps):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            traced = epoch == num_epochs - 1 and step == batch_steps - 1
            _, c = session.run(
                [train_op, loss],
                feed_dict={x: batch_x, y: batch_y},
                options=run_options if traced else None,  # (*)
                run_metadata=run_metadata if traced else None,  # (*)
            )
            # Accumulates the mean per-batch loss for this epoch.
            epoch_loss += c / batch_steps
        print("[%s] Epoch %02d, Loss = %.6f" % (datetime.now(), epoch, epoch_loss))

    # Dump profiling data (*)
    prof_timeline = timeline.Timeline(run_metadata.step_stats)
    prof_ctf = prof_timeline.generate_chrome_trace_format()
    with open(trace_path, 'w') as fp:
        fp.write(prof_ctf)
    # Report only once the file has actually been written and closed.
    print('Dumped to %s' % trace_path)

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Evaluate in the session we were given; Tensor.eval() would silently
    # depend on a default session being installed.
    test_accuracy = session.run(
        accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print("Accuracy: %s" % test_accuracy)
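# Source file: 19-mnist-profiling.py
# (Scraped page footer, translated: "file source code · python · reads 28 ·
#  favorites 0 · likes 0 · comments 0 · comment list · table of contents")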