blog_dset_estmtrs.py source code

python

Project: TFExperiments  Author: gnperdue
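import tensorflow as tf  # TensorFlow 1.x API (tf.decode_csv, make_one_shot_iterator)

# Assumed: names for the four float feature columns (e.g. the Iris CSV).
# The original script defines feature_names elsewhere; adjust to your data.
feature_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']


def my_input(file_path, perform_shuffle=False, repeat_count=1):
    """
    Input function: read a CSV file with the Dataset API and return
    one batch of (features, labels) tensors.
    """
    def decode_csv(line):
        # Record defaults: four float feature columns, then one int label column
        parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
        label = parsed_line[-1:]  # last element is the label
        del parsed_line[-1]       # drop the label, leaving only the features
        features = parsed_line
        d = dict(zip(feature_names, features)), label
        return d

    # Skip the CSV header row, then parse every remaining line
    dataset = (tf.data.TextLineDataset(file_path).skip(1).map(decode_csv))
    if perform_shuffle:
        # Shuffle using a 256-element buffer
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(repeat_count)  # iterate over the data repeat_count times
    dataset = dataset.batch(32)             # batches of 32 records
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels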
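The input function chains `skip`, `map`, `shuffle`, `repeat` and `batch`. As a rough, TensorFlow-free illustration of what each stage does to the records, here is a pure-Python sketch. It is not how `tf.data` is implemented; `my_input_sketch`, `decode_row`, `shuffle_buffer`, `collate` and `FEATURE_NAMES` are hypothetical names, the CSV is passed as a string rather than a file path, and `repeat_count` is assumed to be a finite positive integer.

```python
import csv
import io
import random

# Hypothetical stand-in for feature_names: four float feature columns.
FEATURE_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']


def decode_row(row):
    """Like decode_csv: four float features, then an integer label."""
    *features, label = row
    return dict(zip(FEATURE_NAMES, map(float, features))), int(label)


def shuffle_buffer(items, buffer_size, rng):
    """Buffered shuffle: fill a fixed-size buffer, emit a random element from it."""
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:  # drain what is left once the input is exhausted
        yield buffer.pop(rng.randrange(len(buffer)))


def collate(batch):
    """Turn a list of (features, label) pairs into (dict of lists, list of labels)."""
    features = {name: [f[name] for f, _ in batch] for name in FEATURE_NAMES}
    labels = [label for _, label in batch]
    return features, labels


def my_input_sketch(text, perform_shuffle=False, repeat_count=1,
                    batch_size=32, buffer_size=256, seed=0):
    rows = [r for r in csv.reader(io.StringIO(text)) if r][1:]  # skip(1): drop header
    records = [decode_row(r) for r in rows]                      # map(decode_csv)
    rng = random.Random(seed)

    def epochs():
        # shuffle comes before repeat, so each epoch is shuffled separately
        for _ in range(repeat_count):
            if perform_shuffle:
                yield from shuffle_buffer(records, buffer_size, rng)
            else:
                yield from records

    # batch comes after repeat: a batch may span an epoch boundary,
    # and the final batch may be smaller than batch_size
    out, batch = [], []
    for record in epochs():
        batch.append(record)
        if len(batch) == batch_size:
            out.append(collate(batch))
            batch = []
    if batch:
        out.append(collate(batch))
    return out


if __name__ == "__main__":
    csv_text = ("SepalLength,SepalWidth,PetalLength,PetalWidth,Species\n"
                "5.1,3.5,1.4,0.2,0\n6.2,2.9,4.3,1.3,1\n7.7,3.0,6.1,2.3,2\n")
    result = my_input_sketch(csv_text, repeat_count=2, batch_size=4)
    print([labels for _, labels in result])  # [[0, 1, 2, 0], [1, 2]]
```

With three records, `repeat_count=2` and a batch size of 4, the six records split into one full batch that crosses the epoch boundary and one short final batch. This mirrors why the order `shuffle -> repeat -> batch` in `my_input` reshuffles per epoch while still producing fixed-size batches until the data runs out.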