def graph():
    """
    Build the TensorFlow computation graph for a single-layer linear (softmax) classifier.

    Each input row is the concatenation of a sentence vector (WORD_DIMENSIONS wide),
    an abstract vector (ABSTRACT_DIMENSION wide) and NUM_FEATURES extra features.

    :return: the tuple ``(sentence_input, labels, loss, opt, predictions, pred_answers,
             correct_answers, accuracy)`` where

             - ``sentence_input``: float32 placeholder, shape [batch_size, input_dim]
             - ``labels``: float32 placeholder of one-hot labels, shape [batch_size, NUM_CLASSES]
             - ``loss``: scalar mean softmax cross-entropy over the batch
             - ``opt``: Adam training op (learning rate LEARNING_RATE) minimising ``loss``
             - ``predictions``: softmax class probabilities, shape [batch_size, NUM_CLASSES]
             - ``pred_answers``: predicted class index per example (argmax of the logits)
             - ``correct_answers``: true class index per example (argmax of ``labels``)
             - ``accuracy``: scalar fraction of examples where the two indices agree
    """
    # Width of one input row: sentence, abstract and extra features concatenated.
    input_dim = WORD_DIMENSIONS + ABSTRACT_DIMENSION + NUM_FEATURES

    # Placeholders fed at run time; the leading None lets the batch size vary.
    sentence_input = tf.placeholder(tf.float32, shape=[None, input_dim])
    # One-hot class labels, shape [batch_size, NUM_CLASSES].
    labels = tf.placeholder(tf.float32, shape=[None, NUM_CLASSES])

    # A single affine layer mapping the input straight to unnormalised class scores (logits).
    # weight_variable / bias_variable are defined elsewhere; presumably they create
    # trainable tf.Variables of the given shape — TODO confirm.
    weights = weight_variable([input_dim, NUM_CLASSES])
    bias = bias_variable([NUM_CLASSES])
    output = tf.matmul(sentence_input, weights) + bias

    # Since TF 1.0 this op only accepts keyword arguments (its first parameter is a
    # sentinel), so passing (logits, labels) positionally raises a ValueError.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=output))
    opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

    # Class probabilities for inference.
    predictions = tf.nn.softmax(output)

    # softmax is monotonic, so the argmax of the logits equals the argmax of the probabilities.
    pred_answers = tf.argmax(output, axis=1)
    correct_answers = tf.argmax(labels, axis=1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_answers, correct_answers), tf.float32))

    return sentence_input, labels, loss, opt, predictions, pred_answers, correct_answers, accuracy
# Construct the computation graph
combined_linear_classifier.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录