data_loader.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:DL2W 作者: gauravmm 项目源码 文件源码
def decode(filename_queue, max_frames=40, feature_dim=300, num_classes=4716):
    """Read and decode one serialized example from a TFRecord filename queue.

    Uses the TF1 queue-based reader API (``tf.TFRecordReader``).

    Args:
        filename_queue: Queue of TFRecord filenames, passed to
            ``tf.TFRecordReader.read``.
        max_frames: Number of rows the decoded vector is truncated / zero-padded
            to. Defaults to 40, the value previously hard-coded.
        feature_dim: Width of each row of the decoded vector. Defaults to 300.
        num_classes: Length of the multi-hot label vector. Defaults to 4716.

    Returns:
        A tuple ``(video_id, vector, label)``:
            video_id: scalar string tensor taken from the ``'id'`` feature.
            vector: float32 tensor of shape ``[max_frames, feature_dim]``.
                The raw bytes of ``'vector'`` are reinterpreted as float32 and
                reshaped into rows of ``feature_dim``. The result is then
                truncated or zero-padded to ``max_frames`` rows.
            label: float32 tensor of shape ``[num_classes]`` holding 1.0 at
                each index listed in the ``'label'`` feature and 0.0 elsewhere.
                NOTE(review): presumably every stored label is in
                ``[0, num_classes)``; confirm against the dataset writer.
    """
    # Create TFRecords reader
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Feature keys in TFRecords example
    features = tf.parse_single_example(serialized_example, features={
        'id': tf.FixedLenFeature([], tf.string),
        'vector': tf.FixedLenFeature([], tf.string),
        'label': tf.VarLenFeature(tf.int64)
    })

    video_id = features['id']

    # Decode the raw float32 bytes into rows of `feature_dim` values.
    vector = tf.decode_raw(features['vector'], tf.float32)
    vector = tf.reshape(vector, [-1, feature_dim])
    # Truncate before padding. Otherwise the pad amount
    # (max_frames - num_rows) is negative for long inputs and tf.pad fails.
    vector = vector[:max_frames]
    num_pad_rows = max_frames - tf.shape(vector)[0]
    vector = tf.pad(vector, [[0, num_pad_rows], [0, 0]])
    # The row count is dynamic in the graph; pin the now-guaranteed static shape.
    vector.set_shape([max_frames, feature_dim])

    # Turn the sparse list of label indices into a dense multi-hot vector.
    label = tf.sparse_to_indicator(features['label'], num_classes)
    label.set_shape([num_classes])
    label = tf.cast(label, tf.float32)

    return video_id, vector, label

# Creates an input pipeline for TensorFlow networks
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号