def read_records(index=0):
    """Build the shuffled, batched input tensors for one dataset split.

    Three filename queues are created (training, validation, test), each
    cycling its TFRecord file for ``FLAGS.epochs`` epochs, and ``index``
    selects which of them feeds the record reader.

    Args:
        index: Position of the queue to read from in the order
            [training, validation, test]; forwarded unchanged to
            ``tf.QueueBase.from_list``. Defaults to 0 (training).

    Returns:
        A tuple ``(document_batch, document_weights, query_batch,
        query_weights, answer_batch)``. The ``*_batch`` document/query
        tensors are the dense form of the variable-length int64 features;
        the matching ``*_weights`` tensors have the same dense shape with 1
        at every position that held a value in the sparse input.
    """
    record_files = ['training.tfrecords', 'validation.tfrecords', 'test.tfrecords']
    # One filename queue per split, created in the same order as the list
    # above so that ``index`` maps onto training / validation / test.
    split_queues = [
        tf.train.string_input_producer([path], num_epochs=FLAGS.epochs)
        for path in record_files
    ]
    selected_queue = tf.QueueBase.from_list(index, split_queues)

    # Only the serialized record is needed; the record key is discarded.
    _, serialized_example = tf.TFRecordReader().read(selected_queue)

    feature_spec = {
        'document': tf.VarLenFeature(tf.int64),
        'query': tf.VarLenFeature(tf.int64),
        'answer': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The variable-length features are serialized before batching and
    # restored to sparse tensors afterwards.
    example_tensors = [
        sparse_ops.serialize_sparse(parsed['document']),
        sparse_ops.serialize_sparse(parsed['query']),
        parsed['answer'],
    ]
    batched = tf.train.shuffle_batch(
        example_tensors,
        batch_size=FLAGS.batch_size,
        capacity=2000,
        min_after_dequeue=1000,
    )
    document_serialized, query_serialized, answer_batch = batched

    document_sparse = sparse_ops.deserialize_many_sparse(document_serialized, dtype=tf.int64)
    query_sparse = sparse_ops.deserialize_many_sparse(query_serialized, dtype=tf.int64)

    # Dense values plus a mask-like tensor marking which positions were
    # actually present in the sparse input.
    document_batch = tf.sparse_tensor_to_dense(document_sparse)
    document_weights = tf.sparse_to_dense(
        document_sparse.indices, document_sparse.dense_shape, 1)
    query_batch = tf.sparse_tensor_to_dense(query_sparse)
    query_weights = tf.sparse_to_dense(
        query_sparse.indices, query_sparse.dense_shape, 1)

    return document_batch, document_weights, query_batch, query_weights, answer_batch
评论列表
文章目录