def read(self, shuffle=True, num_epochs=None, min_after_dequeue=100):
    """Build a queue-based input pipeline over the TFRecord file ``self.filename``.

    Each record is parsed as a ``tf.train.Example`` with two int64 features:
    ``'inputs_seq'`` of fixed length ``2 * self.seq_len + 3`` and ``'output'``
    of fixed length 1. Parsed examples are grouped into batches of
    ``self.batch_size``.

    This uses the TF1 queue-runner API. Before evaluating the returned tensors,
    the caller must do two things:

    * Run ``tf.train.start_queue_runners`` (or use a managed session).
    * When ``num_epochs`` is set, also run ``tf.local_variables_initializer()``,
      because ``string_input_producer`` tracks the epoch count in a local
      variable.

    Args:
        shuffle: If True, batch with ``tf.train.shuffle_batch`` (random order,
            2 enqueue threads). Otherwise, batch with ``tf.train.batch`` using
            its default single enqueue thread. The flag is also forwarded to
            the filename queue.
        num_epochs: Number of passes over the file before the input queue is
            closed (subsequent dequeues raise ``tf.errors.OutOfRangeError``).
            ``None`` cycles indefinitely.
        min_after_dequeue: Minimum number of examples kept in the shuffle
            buffer after a dequeue. Larger values give better mixing at the
            cost of memory and start-up time. Only used when ``shuffle`` is
            True.

    Returns:
        A tuple ``(inputs_seqs, outputs)`` of int64 tensors with shapes
        ``[self.batch_size, 2 * self.seq_len + 3]`` and
        ``[self.batch_size, 1]``. Both batching calls keep the default
        ``allow_smaller_final_batch=False``, so a trailing partial batch at
        the end of the last epoch is discarded.
    """
    with tf.name_scope('input'):
        # Forward `shuffle` so the filename queue agrees with the batching mode.
        # (With a single file this is a no-op, but it stays correct if more
        # files are ever added.)
        filename_queue = tf.train.string_input_producer(
            [self.filename], num_epochs=num_epochs, shuffle=shuffle)
        reader = tf.TFRecordReader()
        _, serialized_input = reader.read(filename_queue)

        features = tf.parse_single_example(
            serialized_input,
            features={
                'inputs_seq': tf.FixedLenFeature([self.seq_len * 2 + 3], tf.int64),
                'output': tf.FixedLenFeature([1], tf.int64),
            })
        inputs_seq = features['inputs_seq']
        output = features['output']

        if shuffle:
            # Capacity follows the documented guideline
            # min_after_dequeue + k * batch_size. The extra headroom lets the
            # enqueue threads stay ahead of the dequeues.
            inputs_seqs, outputs = tf.train.shuffle_batch(
                [inputs_seq, output],
                batch_size=self.batch_size,
                num_threads=2,
                capacity=min_after_dequeue + 3 * self.batch_size,
                min_after_dequeue=min_after_dequeue)
        else:
            # Default single enqueue thread keeps records in reader order.
            inputs_seqs, outputs = tf.train.batch(
                [inputs_seq, output], batch_size=self.batch_size)
        return inputs_seqs, outputs
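# 评论列表 ("comment list") — web-page navigation text left over from the scraped source
# 文章目录 ("table of contents") — web-page navigation text left over from the scraped source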