train_lite.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chn_handwriting 作者: zhangchunsheng 项目源码 文件源码
def train_hand_write_cnn(epochs=50, learning_rate=0.001, log_dir='./log'):
    """Build the loss/optimizer/accuracy graph for the handwriting CNN and train it.

    Relies on names defined elsewhere in the module (not visible here):
    ``chinese_hand_write_cnn``, the feed placeholders ``X``, ``Y`` and
    ``keep_prob``, the arrays ``train_data_x``/``train_data_y`` and
    ``text_data_x``/``text_data_y``, and the ints ``num_batch`` and
    ``batch_size``.

    Each step runs one Adam update on a ``batch_size`` slice of the training
    data with ``keep_prob`` fed as 0.5, writes the loss/accuracy summaries for
    TensorBoard and prints ``(step, loss)``. Every 100 steps it also prints
    ``(step, accuracy)`` measured on the first 500 rows of
    ``text_data_x``/``text_data_y`` with ``keep_prob`` fed as 1.0.

    Args:
        epochs: Number of passes over the ``num_batch`` training batches.
        learning_rate: Learning rate for ``tf.train.AdamOptimizer``.
        log_dir: Directory the TensorBoard event file is written to.
    """
    output = chinese_hand_write_cnn()

    # From TF 1.0 on this op must be called with named arguments (a positional
    # call raises ValueError); TF 0.12 accepts the same keywords.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=Y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

    # Fraction of rows whose highest logit matches the argmax of the label row
    # (presumably Y is one-hot -- TODO confirm against the data loader).
    correct = tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    # TensorBoard summaries. tf.summary.* is available since TF 0.12 (the
    # release that also introduced tf.global_variables_initializer, used
    # below); the old tf.scalar_summary / tf.merge_all_summaries /
    # tf.train.SummaryWriter names were removed in TF 1.0.
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    merged_summary_op = tf.summary.merge_all()

    # NOTE(review): the saver is created but never used, so this function
    # writes no checkpoint -- confirm whether saving the model was intended.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # To inspect the logs, run `tensorboard --logdir=<log_dir>` (default
        # ./log) and open http://0.0.0.0:6006 in a browser.
        summary_writer = tf.summary.FileWriter(log_dir, graph=tf.get_default_graph())
        try:
            for e in range(epochs):
                for i in range(num_batch):
                    step = e * num_batch + i
                    start = i * batch_size
                    batch_x = train_data_x[start:start + batch_size]
                    batch_y = train_data_y[start:start + batch_size]
                    _, loss_, summary = sess.run(
                        [optimizer, loss, merged_summary_op],
                        feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.5})
                    # Record this step's summaries for TensorBoard.
                    summary_writer.add_summary(summary, step)
                    print(step, loss_)

                    if step % 100 == 0:
                        # Accuracy on the first 500 rows of the evaluation
                        # arrays (presumably the test set), keep_prob = 1.0.
                        acc = accuracy.eval({X: text_data_x[:500], Y: text_data_y[:500], keep_prob: 1.})
                        print(step, acc)
        finally:
            # Flush buffered events to disk and release the event file, even
            # if training is interrupted.
            summary_writer.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号