transfer_cifar10_softmax_b1.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:deligan 作者: val-iisc 项目源码 文件源码
def add_final_training_ops(graph, class_count, final_tensor_name,
                           ground_truth_tensor_name):
    """Adds a new softmax and fully-connected layer for training.

    We need to retrain the top layer to identify our new classes, so this
    function adds the right operations to the graph, along with some variables
    to hold the weights, and then sets up all the gradients for the backward
    pass.

    The set up for the softmax and fully-connected layers is based on:
    https://tensorflow.org/versions/master/tutorials/mnist/beginners/index.html

    Args:
      graph: Container for the existing model's Graph. All new ops and
        variables are created inside this graph.
      class_count: Integer of how many categories of things we're trying to
        recognize.
      final_tensor_name: Name string for the new final (softmax) node that
        produces results.
      ground_truth_tensor_name: Name string of the placeholder node we feed
        ground truth data into.

    Returns:
      A ``(train_step, cross_entropy_mean)`` tuple:
        train_step: the op that runs one gradient-descent step (learning rate
          ``FLAGS.learning_rate``) minimizing the loss over the graph's
          trainable variables.
        cross_entropy_mean: scalar tensor holding the mean softmax
          cross-entropy over the batch.
    """
    # Build inside `graph` so the new ops live in the same graph as the
    # bottleneck tensor. This is a no-op if `graph` is already the default.
    with graph.as_default():
        bottleneck_tensor = graph.get_tensor_by_name(
            ensure_name_has_port(BOTTLENECK_TENSOR_NAME))
        # NOTE(review): assumes bottleneck_tensor is (batch,
        # BOTTLENECK_TENSOR_SIZE), as required by the matmul below.
        layer_weights = tf.Variable(
            tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count],
                                stddev=0.001),
            name='final_weights')
        layer_biases = tf.Variable(tf.zeros([class_count]),
                                   name='final_biases')
        logits = tf.matmul(bottleneck_tensor, layer_weights,
                           name='final_matmul') + layer_biases
        # The softmax output is not used here. It is created under a known
        # name so callers can fetch predictions by that name.
        tf.nn.softmax(logits, name=final_tensor_name)
        ground_truth_placeholder = tf.placeholder(tf.float32,
                                                  [None, class_count],
                                                  name=ground_truth_tensor_name)
        # Pass unscaled logits (not the softmax output), and use keyword
        # arguments: TF >= 1.0 rejects positional calls to this function.
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=ground_truth_placeholder, logits=logits)
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        train_step = tf.train.GradientDescentOptimizer(
            FLAGS.learning_rate).minimize(cross_entropy_mean)
    return train_step, cross_entropy_mean

# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号