data_generator.py source code

Language: python

Project: SSD_tensorflow_VOC    Author: LevinJ
import tensorflow as tf
import tensorflow.contrib.slim as slim

# `flowers` is the flowers dataset definition shipped with the TF-Slim models
# repository (tensorflow/models, research/slim/datasets). `load_batch`,
# `my_cnn`, `flowers_data_dir` and `train_dir` are assumed to be defined
# elsewhere in this file; minimal sketches are given after the function.
from datasets import flowers


def train_save_model():
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        dataset = flowers.get_split('train', flowers_data_dir)
        images, _, labels = load_batch(dataset)

        # Create the model:
        logits = my_cnn(images, num_classes=dataset.num_classes, is_training=True)

        # Specify the loss function:
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
        slim.losses.softmax_cross_entropy(logits, one_hot_labels)
        total_loss = slim.losses.get_total_loss()

        # Create some summaries to visualize the training process:
        tf.summary.scalar('losses/total_loss', total_loss)

        # Specify the optimizer and create the train op:
        optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        # Run the training:
        final_loss = slim.learning.train(
          train_op,
          logdir=train_dir,
          number_of_steps=1,  # For speed, run only a single training step
          save_summaries_secs=1)

        print('Finished training. Final batch loss %f' % final_loss)
    return
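The snippet above relies on helpers that are not shown in this excerpt. Below is a minimal sketch of what they might look like, modeled on the TF-Slim flowers walkthrough: the directory paths, batch size, image size, and layer widths are assumptions for illustration, not the project's actual values, and the real `load_batch` and `my_cnn` in the repository may differ.

# --- Assumed helpers (illustrative sketch, not the project's originals) ---

import tensorflow as tf
import tensorflow.contrib.slim as slim

flowers_data_dir = '/tmp/flowers'    # assumed location of the flowers TFRecords
train_dir = '/tmp/tfslim_model/'     # assumed checkpoint/summary directory


def load_batch(dataset, batch_size=32, height=299, width=299, is_training=True):
    """Return batched (images, raw_images, labels) tensors from a slim Dataset."""
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, common_queue_capacity=32, common_queue_min=8)
    image_raw, label = data_provider.get(['image', 'label'])

    # Resize for the network and scale pixels to [0, 1).
    image = tf.image.resize_images(image_raw, [height, width])
    image = tf.to_float(image) / 255.0

    # Keep a resized but otherwise unprocessed copy for visualization.
    image_raw = tf.image.resize_images(image_raw, [height, width])

    images, images_raw, labels = tf.train.batch(
        [image, image_raw, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=2 * batch_size)
    return images, images_raw, labels


def my_cnn(images, num_classes, is_training):
    """A small CNN classifier; the layer sizes here are illustrative only."""
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu):
        net = slim.conv2d(images, 32, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.conv2d(net, 64, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.flatten(net)
        net = slim.fully_connected(net, 192, scope='fc1')
        net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
        net = slim.fully_connected(net, num_classes,
                                   activation_fn=None, scope='logits')
    return net

With definitions along these lines, calling train_save_model() builds the graph, runs one training step on the flowers data, and writes the checkpoint and summaries to train_dir.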