read_rec.py source code

python

Project: Stock-Predict-RNN  Author: daiab
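import tensorflow as tf  # TensorFlow 1.x queue-based input API

# NOTE: `cfg` is the project's configuration object (time_step, batch_size,
# is_training); its import is not shown in this snippet.


def read_and_decode(record_file):
    """Read a TFRecord file and return (data, label, target) tensors with a leading batch dimension."""
    print(record_file)
    # read_and_decode_test(record_file)
    # Queue of input file names. `capacity` is documented as an integer,
    # so an int literal is used here instead of the float 1e5.
    data_queue = tf.train.input_producer([record_file], capacity=100000, name="string_input_producer")
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(data_queue)
    # Each example holds a scalar int64 label, a scalar float32 target,
    # and a flat float32 vector of time_step * 4 values.
    features = tf.parse_single_example(
        serialized_example,
        features={'label': tf.FixedLenFeature([], tf.int64),
                  'target': tf.FixedLenFeature([], tf.float32),
                  'data': tf.FixedLenFeature([cfg.time_step * 4], tf.float32)})
    data_raw = features['data']
    label = features['label']
    target = features['target']
    # Restore the [time_step, 4] layout of the flat vector.
    data = tf.reshape(data_raw, [cfg.time_step, 4])
    data.set_shape([cfg.time_step, 4])
    if cfg.is_training:
        # Training: group examples into batches of cfg.batch_size using 4 reader threads.
        data_batch, label_batch, target_batch = tf.train.batch([data, label, target],
                                                     batch_size=cfg.batch_size,
                                                     capacity=cfg.batch_size * 50,
                                                     num_threads=4)
        return data_batch, label_batch, target_batch
    else:
        # Inference: return a single example with a batch dimension of 1.
        return tf.expand_dims(data, 0), tf.expand_dims(label, 0), tf.expand_dims(target, 0)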
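
To make the shapes concrete, here is a minimal pure-Python sketch (no TensorFlow required) of what the reshape step does to one record's flat `data` feature, assuming the row-major order that `tf.reshape` uses. The helper `reshape_record`, the value `time_step = 3`, and the sample numbers are hypothetical and only for illustration; the source does not say what the four columns represent.

```python
def reshape_record(flat, time_step, width=4):
    """Split a flat list of time_step * width values into time_step rows (row-major)."""
    if len(flat) != time_step * width:
        raise ValueError("expected %d values, got %d" % (time_step * width, len(flat)))
    return [flat[i * width:(i + 1) * width] for i in range(time_step)]


# Hypothetical record: 3 time steps with 4 values per step.
time_step = 3
flat = [float(v) for v in range(time_step * 4)]
data = reshape_record(flat, time_step)

# Inference branch: wrap the example in a leading batch dimension of size 1,
# analogous to tf.expand_dims(data, 0).
batch_of_one = [data]

for row in data:
    print(row)
```

Each printed row corresponds to one time step, so the first four values of the flat vector form step 0, the next four form step 1, and so on. In the training branch the same per-example layout is stacked `cfg.batch_size` times instead of once.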