combined_linear_classifier.py source code

python

Project: scientific-paper-summarisation    Author: EdCo95
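import tensorflow as tf  # uses the TensorFlow 1.x API (tf.placeholder, tf.train.AdamOptimizer)

# WORD_DIMENSIONS, ABSTRACT_DIMENSION, NUM_FEATURES, NUM_CLASSES, LEARNING_RATE and the
# weight_variable / bias_variable helpers come from the rest of the original file (not shown here).

def graph():
    """
    Function to encapsulate the construction of a TensorFlow computation graph for a linear
    (softmax regression) sentence classifier.
    :return: input and label placeholders, loss, optimisation operation, class-probability predictions,
             predicted class indices, true class indices, accuracy
    """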

    # Define placeholders for the data

    # The sentence to classify, of shape [batch_size x (WORD_DIMENSIONS + ABSTRACT_DIMENSION + NUM_FEATURES)],
    # because the input is the sentence vector, the abstract vector and the extra features concatenated.
    sentence_input = tf.placeholder(tf.float32, shape=[None, WORD_DIMENSIONS + ABSTRACT_DIMENSION + NUM_FEATURES])

    # The labels for the sentences as one-hot vectors, of the form [batch_size x num_classes]
    labels = tf.placeholder(tf.float32, shape=[None, NUM_CLASSES])

    # Define the computation graph

    # A single linear layer mapping the input to unnormalised class scores (logits)
    keep_weight = weight_variable([WORD_DIMENSIONS + ABSTRACT_DIMENSION + NUM_FEATURES, NUM_CLASSES])
    keep_bias = bias_variable([NUM_CLASSES])
    output = tf.matmul(sentence_input, keep_weight) + keep_bias

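    # Define the loss function: mean softmax cross-entropy over the batch
    # (TensorFlow >= 1.0 requires labels= and logits= to be passed as keyword arguments)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=output))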
    opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

    # Predictions
    predictions = tf.nn.softmax(output)

    # Calculate accuracy (softmax preserves the ordering of the logits, so argmax over output gives the predicted class)
    pred_answers = tf.argmax(output, axis=1)
    correct_answers = tf.argmax(labels, axis=1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_answers, correct_answers), tf.float32))

    return sentence_input, labels, loss, opt, predictions, pred_answers, correct_answers, accuracy

# Construct the computation graph
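To make the arithmetic inside graph() concrete without a TensorFlow install, the following pure-Python sketch recomputes the same quantities on a toy batch: the logits x·W + b (output), the batch-mean softmax cross-entropy (loss), and the argmax-based accuracy. It is an illustrative re-implementation under stated assumptions, not the project's code: the function names and the toy data are hypothetical, and sgd_step uses plain gradient descent as a simpler stand-in for the AdamOptimizer used above. It relies on the standard fact that the gradient of the mean softmax cross-entropy with respect to the logits is (softmax(logits) - labels) / batch_size.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def linear_logits(x: Matrix, w: Matrix, b: Sequence[float]) -> Matrix:
    """logits = x @ w + b, the same affine map as `output` in graph()."""
    return [
        [sum(row[k] * w[k][j] for k in range(len(w))) + b[j] for j in range(len(b))]
        for row in x
    ]


def softmax(z: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def mean_softmax_cross_entropy(logits: Matrix, labels: Matrix) -> float:
    """Batch mean of -sum_j labels_j * log(softmax(logits)_j), via log-sum-exp."""
    losses = []
    for z, y in zip(logits, labels):
        m = max(z)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        losses.append(-sum(yj * (zj - log_sum_exp) for zj, yj in zip(z, y)))
    return sum(losses) / len(losses)


def argmax(row: Sequence[float]) -> int:
    """Index of the largest entry (first one on ties)."""
    return max(range(len(row)), key=lambda j: row[j])


def accuracy(logits: Matrix, labels: Matrix) -> float:
    """Fraction of rows where argmax(logits) == argmax(one-hot label)."""
    hits = [argmax(z) == argmax(y) for z, y in zip(logits, labels)]
    return sum(hits) / len(hits)


def sgd_step(x: Matrix, labels: Matrix, w: Matrix, b: List[float],
             lr: float) -> Tuple[Matrix, List[float]]:
    """One plain gradient-descent step (hypothetical stand-in for Adam)."""
    n = len(x)
    # d(mean loss)/d(logits) = (softmax(logits) - labels) / n
    grad_logits = [
        [(p - yj) / n for p, yj in zip(softmax(z), y)]
        for z, y in zip(linear_logits(x, w, b), labels)
    ]
    # Chain rule through logits = x @ w + b
    new_w = [
        [w[k][j] - lr * sum(x[i][k] * grad_logits[i][j] for i in range(n))
         for j in range(len(b))]
        for k in range(len(w))
    ]
    new_b = [b[j] - lr * sum(g[j] for g in grad_logits) for j in range(len(b))]
    return new_w, new_b


# Toy batch (hypothetical): 2 examples, 3 input features, 2 classes.
x = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]
labels = [[1.0, 0.0], [0.0, 1.0]]
w = [[0.0, 0.0] for _ in range(3)]
b = [0.0, 0.0]

loss_before = mean_softmax_cross_entropy(linear_logits(x, w, b), labels)
acc_before = accuracy(linear_logits(x, w, b), labels)
w, b = sgd_step(x, labels, w, b, lr=0.5)
loss_after = mean_softmax_cross_entropy(linear_logits(x, w, b), labels)
acc_after = accuracy(linear_logits(x, w, b), labels)
print(round(loss_before, 4), acc_before)  # 0.6931 0.5
print(round(loss_after, 4), acc_after)    # 0.5759 1.0
```

With all-zero weights every logit is 0, so the loss starts at ln 2 and both rows are predicted as class 0 (accuracy 0.5). One step with learning rate 0.5 moves the logits to [0.125, -0.125] and [-0.125, 0.125], which lowers the loss to ln(1 + e^-0.25) and classifies both rows correctly. Because softmax is monotonic, taking argmax over the logits (as graph() does for pred_answers) picks the same class as argmax over the probabilities.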