model.py source code

python

Project: attention-over-attention-tf-QA  Author: lc222
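import tensorflow as tf  # TensorFlow 1.x queue-based input API
from tensorflow.python.ops import sparse_ops

# FLAGS (epochs, batch_size) is assumed to be defined elsewhere in this file,
# e.g. through tf.app.flags.

def read_records(index=0):
  """Build the input pipeline; index picks train (0), validation (1) or test (2)."""
  # num_epochs creates local variables, so run tf.local_variables_initializer()
  # before starting the queue runners.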
  train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
  validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
  test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)

  # Select one of the three filename queues according to index.
  queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
        'document': tf.VarLenFeature(tf.int64),
        'query': tf.VarLenFeature(tf.int64),
        'answer': tf.FixedLenFeature([], tf.int64)
      })

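  # Serialize the variable-length sparse features so they can be batched
  # like ordinary fixed-shape tensors.
  document = sparse_ops.serialize_sparse(features['document'])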
  query = sparse_ops.serialize_sparse(features['query'])
  answer = features['answer']

  document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
      [document, query, answer], batch_size=FLAGS.batch_size,
      capacity=2000,
      min_after_dequeue=1000)

  sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
  sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)

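  # Dense batches are zero-padded to the longest sequence in the batch;
  # the weights are 1 at real token positions and 0 at padding.
  document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
  document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.dense_shape, 1)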

  query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
  query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.dense_shape, 1)

  return document_batch, document_weights, query_batch, query_weights, answer_batch
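
The final step of read_records turns variable-length token sequences into a padded matrix plus a weight mask. The sketch below illustrates that idea in plain Python, without TensorFlow. The helper name `pad_and_mask` and the sample ids are hypothetical and not part of the original project; it is only meant to show what the dense batch and its weights are expected to look like.

```python
def pad_and_mask(sequences, pad_value=0):
    """Pad variable-length id lists to a rectangle and build a 0/1 weight mask.

    Hypothetical helper mirroring, in spirit, the sparse-to-dense step:
    dense rows are padded with pad_value, and the weights mark real tokens.
    """
    max_len = max((len(seq) for seq in sequences), default=0)
    dense, weights = [], []
    for seq in sequences:
        padding = max_len - len(seq)
        dense.append(list(seq) + [pad_value] * padding)
        # 1 for every real token, 0 for every padded position
        weights.append([1] * len(seq) + [0] * padding)
    return dense, weights


if __name__ == "__main__":
    docs = [[5, 9, 2], [7], [3, 4]]  # made-up token ids
    dense, weights = pad_and_mask(docs)
    print(dense)
    print(weights)
```

Keeping the mask separate from the padded ids matters because a real token may itself have id 0; the weights, not the values, say which positions are padding.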