model.py source code

python

Project: attention-over-attention-tf-QA, author: lc222. The inference function below builds the Attention-over-Attention reader graph in TensorFlow 1.x: it embeds the document and query, encodes both with bidirectional GRUs, forms the pairwise matching matrix, and converts the attended-over-attention scores into per-word answer scores.
import tensorflow as tf

# FLAGS, softmax() and orthogonal_initializer() are defined elsewhere in model.py.

def inference(documents, doc_mask, query, query_mask):
  # Word-embedding matrix shared by documents and queries.
  embedding = tf.get_variable('embedding',
              [FLAGS.vocab_size, FLAGS.embedding_size],
              initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05))

  # L2 penalty on the embeddings, returned so the caller can add it to the loss.
  regularizer = tf.nn.l2_loss(embedding)

  # Embed both sequences and apply dropout: [batch, time, embedding_size].
  doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob)
  doc_emb.set_shape([None, None, FLAGS.embedding_size])

  query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob)
  query_emb.set_shape([None, None, FLAGS.embedding_size])

  # Encode the document with a bidirectional GRU.
  with tf.variable_scope('document', initializer=orthogonal_initializer()):
    fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
    back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

    doc_len = tf.reduce_sum(doc_mask, axis=1)  # true (unpadded) lengths
    h, _ = tf.nn.bidirectional_dynamic_rnn(
        fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)
    # h_doc: [batch, doc_len, 2 * hidden_size], forward and backward outputs concatenated.
    #h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)  # pre-TF-1.0 variant
    h_doc = tf.concat(h, 2)

  # Encode the query the same way, in its own variable scope.
  with tf.variable_scope('query', initializer=orthogonal_initializer()):
    fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
    back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

    query_len = tf.reduce_sum(query_mask, axis=1)
    h, _ = tf.nn.bidirectional_dynamic_rnn(
        fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32)
    #h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)  # pre-TF-1.0 variant
    h_query = tf.concat(h, 2)

  # Pairwise matching matrix: M[b, i, j] = h_doc[b, i] . h_query[b, j].
  M = tf.matmul(h_doc, h_query, adjoint_b=True)
  # Zero out pairs where either position is padding.
  M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1)))

  # alpha: masked softmax over document positions (one distribution per query word).
  alpha = softmax(M, 1, M_mask)
  # beta: masked softmax over query positions (one distribution per document word).
  beta = softmax(M, 2, M_mask)

  # Average beta over the document positions to get an importance weight per query word.
  #query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1)
  query_importance = tf.expand_dims(
      tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1)

  # Attention-over-attention score for every document position: [batch, doc_len].
  s = tf.squeeze(tf.matmul(alpha, query_importance), [2])

  # Pointer-sum aggregation: each word id's score is the total attention mass
  # it receives across all of its occurrences in the document.
  unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size))
  y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size)
                    for (attentions, sentence_ids) in unpacked_s])

  return y_hat, regularizer
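
The listing calls two helpers, softmax and orthogonal_initializer, that are defined elsewhere in model.py and not shown above. Below is a minimal sketch of what they plausibly look like, assuming a masked, numerically stable softmax along a chosen axis and the SVD-based orthogonal initialization of Saxe et al.; the repository's actual definitions may differ in detail.

python

import numpy as np
import tensorflow as tf

# Hypothetical reconstructions of the helpers used by inference();
# the actual code in model.py may differ.

def softmax(target, axis, mask, epsilon=1e-12):
  # Numerically stable softmax along `axis` that assigns padded
  # positions (mask == 0) exactly zero probability.
  max_axis = tf.reduce_max(target, axis, keep_dims=True)
  target_exp = tf.exp(target - max_axis) * mask
  normalize = tf.reduce_sum(target_exp, axis, keep_dims=True)
  return target_exp / (normalize + epsilon)

def orthogonal_initializer(scale=1.1):
  # Orthogonal weight initialization via the SVD of a random Gaussian
  # matrix (Saxe et al., 2013), commonly used for RNN kernels.
  def _initializer(shape, dtype=tf.float32, partition_info=None):
    flat = (shape[0], int(np.prod(shape[1:])))
    a = np.random.normal(0.0, 1.0, flat)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat else v  # whichever factor has the target shape
    return tf.constant(scale * q.reshape(shape), dtype=dtype)
  return _initializer

With these in place, y_hat is a [batch_size, vocab_size] tensor of unnormalized per-word answer scores, and regularizer can be scaled and added to the training loss.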