batch_data_test.py source code

python

Project: tensorflow-deep-qa  Author: shuishen112
import tensorflow as tf  # TensorFlow 1.x queue-based input API


def batch_generator(filenames):
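    """Build (data_batch, label_batch) tensors from tab-separated text files.

    filenames is the list of files you want to read from.
    In this case, it contains only heart.csv.
    Relies on the module-level constants N_FEATURES and BATCH_SIZE.
    """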
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TextLineReader(skip_header_lines=1) # skip the first line in the file
    _, value = reader.read(filename_queue)
    record_defaults = [[''] for _ in range(N_FEATURES)]
    # parse one tab-separated line into a list of N_FEATURES scalar string tensors
    content = tf.decode_csv(value, record_defaults=record_defaults, field_delim='\t')


    # pack every column except the last into a single feature tensor
    features = tf.stack(content[:N_FEATURES - 1])

    # assign the last column to label
    label = content[-1]

    # minimum number elements in the queue after a dequeue, used to ensure 
    # that the samples are sufficiently mixed
    # I think 10 times the BATCH_SIZE is sufficient
    min_after_dequeue = 10 * BATCH_SIZE

    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

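    # shuffle the data and generate batches of BATCH_SIZE (features, label) pairs;
    # tf.train.shuffle_batch is required here because tf.train.batch does not
    # accept min_after_dequeue
    data_batch, label_batch = tf.train.shuffle_batch([features, label], batch_size=BATCH_SIZE,
                                        capacity=capacity, min_after_dequeue=min_after_dequeue,
                                        allow_smaller_final_batch=True)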

    return data_batch, label_batch
    # return features,label
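
The comments above describe `min_after_dequeue` as a floor that keeps samples well mixed and `capacity` as the queue's upper bound. The following is a minimal pure-Python sketch of that idea, not TensorFlow's implementation. The helper name `shuffle_batches` and its single-threaded fill-then-drain loop are assumptions made for illustration; the real queue is filled and drained by separate threads.

```python
import random


def shuffle_batches(records, batch_size, capacity, min_after_dequeue, seed=None):
    """Yield shuffled batches from an iterable using a bounded buffer.

    Hypothetical helper for illustration only; it is not part of TensorFlow.
    """
    if capacity <= min_after_dequeue:
        raise ValueError("capacity must be larger than min_after_dequeue")
    rng = random.Random(seed)
    source = iter(records)
    buffer, batch = [], []
    exhausted = False
    while True:
        # Producer side: top the buffer up to `capacity` elements.
        while not exhausted and len(buffer) < capacity:
            try:
                buffer.append(next(source))
            except StopIteration:
                exhausted = True
        # Consumer side: while input remains, always leave at least
        # `min_after_dequeue` elements behind so each draw has many candidates.
        floor = 0 if exhausted else min_after_dequeue
        while len(buffer) > floor:
            idx = rng.randrange(len(buffer))
            # Swap the chosen element to the end so that pop() is O(1).
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            batch.append(buffer.pop())
            if len(batch) == batch_size:
                yield batch
                batch = []
        if exhausted:
            break
    if batch:
        # Smaller final batch, analogous to allow_smaller_final_batch=True.
        yield batch


batches = list(shuffle_batches(range(25), batch_size=4, capacity=8,
                               min_after_dequeue=4, seed=0))
```

In this sketch a larger `min_after_dequeue` means each random draw picks from more candidates, so the output order is less correlated with the input order, at the cost of holding more records in memory before the first batch is produced.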