def train(y_hat, regularizer, document, doc_weight, answer):
    # Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206
    # This unfortunately forces us to expand a sparse tensor into the full vocabulary:
    # y_hat is flattened to [batch_size * vocab_size], and each example's correct-answer
    # probability is picked out by its offset into the flat tensor.
    index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer)
    flat = tf.reshape(y_hat, [-1])
    relevant = tf.gather(flat, index)

    # reduce_mean keeps the loss scale independent of batch size, since the
    # regularization term does not grow with it.
    loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer

    global_step = tf.Variable(0, name="global_step", trainable=False)

    accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer)))

    optimizer = tf.train.AdamOptimizer()
    grads_and_vars = optimizer.compute_gradients(loss)
    # Clip gradients to [-5, 5] to guard against exploding gradients.
    capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars]
    train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)

    return loss, train_op, global_step, accuracy
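
Newer TensorFlow releases do provide tf.gather_nd, so the flatten-and-offset workaround above can be replaced by indexing (row, answer) pairs directly. The following is a minimal sketch, not part of the original code, assuming y_hat has shape [batch_size, vocab_size] and answer is a vector of integer answer ids:

# Sketch: equivalent of the flatten + tf.gather trick using tf.gather_nd.
# Assumes y_hat is [batch_size, vocab_size] and answer is an integer vector.
indices = tf.stack([tf.range(FLAGS.batch_size), tf.to_int32(answer)], axis=1)
relevant = tf.gather_nd(y_hat, indices)  # picks y_hat[i, answer[i]] for each example i
loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer

This avoids materializing indices over the full batch_size * vocab_size flattened tensor and keeps the loss computation otherwise identical.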