data_load.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:neurobind 作者: Kyubyong 项目源码 文件源码
def get_batch_data():
    # Load data
    X, Y = load_data()

    # calc total batch count
    num_batch = len(X) // hp.batch_size

    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.float32)

    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])

    # create batch queues
    x, y = tf.train.batch(input_queues,
                          num_threads=8,
                          batch_size=hp.batch_size,
                          capacity=hp.batch_size * 64,
                          allow_smaller_final_batch=False)

    return x, y, num_batch  # (N, T), (N, T), ()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号