model.py source code

python

Project: attention-over-attention-tf-QA  Author: lc222
import tensorflow as tf

# FLAGS (batch_size, vocab_size, l2_reg) is assumed to be defined at module level in model.py.
def train(y_hat, regularizer, document, doc_weight, answer):
  # Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206
  # This unfortunately causes us to expand a sparse tensor into the full vocabulary
  index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer)
  flat = tf.reshape(y_hat, [-1])
  relevant = tf.gather(flat, index)
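  # e.g. with batch_size=2, vocab_size=5 and answer=[3, 1], index=[3, 6],
  # so `relevant` picks y_hat[0, 3] and y_hat[1, 1] out of the flattened tensor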

  # use the mean (not the sum) because the regularizer term does not scale with batch size
  loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer

  global_step = tf.Variable(0, name="global_step", trainable=False)

  accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer)))

  optimizer = tf.train.AdamOptimizer()
  grads_and_vars = optimizer.compute_gradients(loss)
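  # clip each gradient element-wise to [-5, 5] before applying the update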
  capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars]
  train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step)

  tf.summary.scalar('loss', loss)
  tf.summary.scalar('accuracy', accuracy)
  return loss, train_op, global_step, accuracy
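
The flattened-gather trick in train() predates tf.gather_nd (see the issue linked in the comment above). A minimal standalone sketch of the equivalence, using made-up shapes and values rather than the project's data:

import tensorflow as tf

batch_size, vocab_size = 2, 5
y_hat = tf.constant([[0.1, 0.2, 0.3, 0.3, 0.1],
                     [0.4, 0.3, 0.1, 0.1, 0.1]])
answer = tf.constant([3, 1])

# The trick used above: flatten y_hat and gather at row * vocab_size + column.
index = tf.range(0, batch_size) * vocab_size + answer
relevant_flat = tf.gather(tf.reshape(y_hat, [-1]), index)

# The same per-example lookup with tf.gather_nd: one (row, column) pair per example.
pairs = tf.stack([tf.range(batch_size), answer], axis=1)  # [[0, 3], [1, 1]]
relevant_nd = tf.gather_nd(y_hat, pairs)

# Both select [0.3, 0.3], i.e. y_hat[0, 3] and y_hat[1, 1].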