def get_batch_data():
    """Build queue-backed mini-batch tensors from the data given by ``load_data``.

    The loaded inputs are converted to an int32 tensor and the targets to a
    float32 tensor. Both are handed to ``tf.train.slice_input_producer`` and
    the result is grouped by ``tf.train.batch`` into batches of
    ``hp.batch_size`` examples.

    Returns:
        A 3-tuple ``(x, y, num_batch)``: the batched input tensor, the
        batched target tensor, and ``len(X) // hp.batch_size``, i.e. the
        number of full batches contained in the loaded data.

    NOTE(review): the original author annotated the returned shapes as
    ``(N, T), (N, T), ()``. The real shapes depend on what ``load_data``
    returns -- confirm against that function.
    NOTE(review): these are TF1 queue-based input ops; the caller presumably
    has to start the queue runners before evaluating ``x`` / ``y`` -- confirm
    against the training loop.
    """
    inputs, targets = load_data()
    batch_size = hp.batch_size

    # Floor division: a trailing partial batch is not counted.
    num_batch = len(inputs) // batch_size

    # Fixed dtypes: int32 for the inputs, float32 for the targets.
    input_tensor = tf.convert_to_tensor(inputs, tf.int32)
    target_tensor = tf.convert_to_tensor(targets, tf.float32)

    # Queue that serves the (input, target) tensors to the batching op.
    example_queue = tf.train.slice_input_producer([input_tensor, target_tensor])

    # Batching options, kept identical to the original call.
    batch_options = dict(
        num_threads=8,
        batch_size=batch_size,
        capacity=batch_size * 64,
        allow_smaller_final_batch=False,
    )
    x, y = tf.train.batch(example_queue, **batch_options)

    return x, y, num_batch
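# Web-page residue (not code): "评论列表" (comment list)
# Web-page residue (not code): "文章目录" (table of contents)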