training.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pythonLean 作者: 527515025 项目源码 文件源码
def run_training(train_dir='/Users/yangyibo/GitWork/pythonLean/AI/????/img/',
                 logs_train_dir='/Users/yangyibo/GitWork/pythonLean/AI/????/saveNet/'):
    """Build the training graph, run up to MAX_STEP optimization steps, and checkpoint.

    Relies on names defined elsewhere in this module: ``input_data``, ``model``,
    ``tf``, ``os``, ``IMG_W``, ``IMG_H``, ``BATCH_SIZE``, ``CAPACITY``,
    ``N_CLASSES``, ``learning_rate`` and ``MAX_STEP``.

    Args:
        train_dir: Directory handed to ``input_data.get_files`` to collect the
            training data. Defaults to the original author's local path.
        logs_train_dir: Directory that receives the TensorBoard event files and
            the ``model.ckpt-<step>`` checkpoints written by ``tf.train.Saver``.
    """
    # Collect the training samples and their labels from train_dir
    # (presumably image file paths + integer labels -- see input_data.get_files).
    train, train_label = input_data.get_files(train_dir)
    # Build batched input tensors (IMG_W x IMG_H, BATCH_SIZE per batch,
    # queue capacity CAPACITY) from the sample/label lists.
    train_batch, train_label_batch = input_data.get_batch(train,
                                                          train_label,
                                                          IMG_W,
                                                          IMG_H,
                                                          BATCH_SIZE,
                                                          CAPACITY)
    # Forward pass: logits for N_CLASSES classes.
    train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
    # Loss tensor.
    train_loss = model.losses(train_logits, train_label_batch)
    # Optimization op.
    train_op = model.trainning(train_loss, learning_rate)
    # Batch accuracy tensor.
    train_acc = model.evaluation(train_logits, train_label_batch)

    # Merge all summaries defined in the graph; merge_all() returns None when
    # the graph defines no summaries, so that case is guarded below.
    summary_op = tf.summary.merge_all()
    sess = tf.Session()
    # Writer for TensorBoard event files (also records the graph).
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    saver = tf.train.Saver()
    checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            log_step = step % 50 == 0
            if log_step and summary_op is not None:
                # Fetch the summary in the SAME run as the training step so it
                # describes the same batch as the printed loss/accuracy and does
                # not dequeue an extra batch from the input queue.
                _, tra_loss, tra_acc, summary_str = sess.run(
                    [train_op, train_loss, train_acc, summary_op])
                train_writer.add_summary(summary_str, step)
            else:
                _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])

            if log_step:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))

            # Save a checkpoint every 2000 steps (including step 0) and after
            # the final step.
            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        # The input queue ran out of data (epoch limit reached).
        print('Done training -- epoch limit reached')
    finally:
        # Always stop the queue-runner threads and release the writer/session,
        # even if training raised an unexpected error.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            train_writer.close()
            sess.close()

# train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号