def train_hand_write_cnn(num_epochs=50, learning_rate=0.001, log_dir='./log'):
    """Train the handwriting CNN and log loss/accuracy for TensorBoard.

    Builds the network with ``chinese_hand_write_cnn()``, attaches a
    softmax cross-entropy loss and an Adam optimizer, then runs mini-batch
    training, printing the loss at every step and an evaluation accuracy
    every 100 steps.

    Relies on names defined elsewhere in this module: ``X``, ``Y``,
    ``keep_prob`` (fed through ``feed_dict``; presumably placeholders),
    ``train_data_x``, ``train_data_y``, ``text_data_x``, ``text_data_y``,
    ``num_batch`` and ``batch_size``.

    Args:
        num_epochs: Number of passes over the ``num_batch`` training batches.
        learning_rate: Learning rate handed to ``tf.train.AdamOptimizer``.
        log_dir: Directory the TensorBoard event files are written to.
    """
    output = chinese_hand_write_cnn()

    # Named arguments are required from TensorFlow 1.0 on; passing them
    # positionally is rejected there.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=output))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1)), tf.float32))

    # TensorBoard summaries (tf.summary.* replaces the tf.scalar_summary /
    # tf.merge_all_summaries / tf.train.SummaryWriter API removed in TF 1.0).
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    merged_summary_op = tf.summary.merge_all()

    # NOTE(review): the saver is created but never used in the visible code;
    # kept so the graph is unchanged -- confirm whether a checkpoint save
    # was intended.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # To inspect the run: `tensorboard --logdir=<log_dir>` and open the
        # URL it prints (http://0.0.0.0:6006 by default) in a browser.
        summary_writer = tf.summary.FileWriter(log_dir, graph=tf.get_default_graph())
        try:
            for e in range(num_epochs):
                for i in range(num_batch):
                    step = e * num_batch + i
                    batch_x = train_data_x[i * batch_size:(i + 1) * batch_size]
                    batch_y = train_data_y[i * batch_size:(i + 1) * batch_size]
                    _, loss_, summary = sess.run(
                        [optimizer, loss, merged_summary_op],
                        feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.5})

                    # Record the merged summaries for every training step.
                    summary_writer.add_summary(summary, step)
                    print(step, loss_)

                    if step % 100 == 0:
                        # Accuracy on the first 500 evaluation samples, with
                        # keep_prob fed as 1.
                        # NOTE(review): `text_data_*` may be a typo for
                        # `test_data_*`; defined outside this block -- verify.
                        acc = sess.run(
                            accuracy,
                            feed_dict={X: text_data_x[:500],
                                       Y: text_data_y[:500],
                                       keep_prob: 1.})
                        print(step, acc)
        finally:
            # Flush and release the event file even if training is interrupted.
            summary_writer.close()
# Comment list
# Article table of contents