utilities.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:unsupervised-2017-cvprw 作者: imatge-upc 项目源码 文件源码
def read_my_file_format_dis(filename_queue, is_training, height=128, width=128,
                            sequence_length=32, eval_offset=8):
    """Read one SequenceExample and return an (augmented) clip, label and text.

    Builds TF1 graph ops that read one serialized record from
    ``filename_queue``, parse it, decode its frames with ``decode_frames``,
    keep ``FLAGS.seq_length`` consecutive frames and map every pixel value v
    to ``v / 255 * 2 - 1`` (i.e. [0, 255] -> [-1, 1]). When ``is_training`` is
    true, the clip start is sampled at random and the clip is randomly
    reversed in time and randomly flipped horizontally.

    Args:
        filename_queue: queue of TFRecord filenames consumed by
            ``tf.TFRecordReader.read``.
        is_training: Python bool selecting random sampling/augmentation
            (True) or a fixed clip start at ``eval_offset`` (False).
        height: frame height passed to ``decode_frames`` and used for the
            static output shape (default 128).
        width: frame width passed to ``decode_frames`` and used for the
            static output shape (default 128).
        sequence_length: number of frames per record passed to
            ``decode_frames`` (default 32).
        eval_offset: first frame of the clip when ``is_training`` is False
            (default 8).

    Returns:
        A tuple ``(clip, label, text)``:
        ``clip`` has static shape ``[FLAGS.seq_length, height, width, 3]``;
        ``label`` is a one-hot vector of depth ``FLAGS.num_class``;
        ``text`` is the record's ``text`` string tensor.

    Raises:
        ValueError: if ``FLAGS.seq_length`` frames starting at the chosen
            offset do not fit inside ``sequence_length`` frames.
    """
    clip_length = FLAGS.seq_length

    # Validate the static clip geometry before building any ops. Otherwise an
    # oversized clip only surfaces later, either as a set_shape mismatch
    # (eval) or as an invalid random_uniform range at run time (training).
    if clip_length > sequence_length:
        raise ValueError(
            "FLAGS.seq_length (%d) exceeds sequence_length (%d)"
            % (clip_length, sequence_length))
    if not is_training and not 0 <= eval_offset <= sequence_length - clip_length:
        raise ValueError(
            "eval_offset (%d) + FLAGS.seq_length (%d) must lie within "
            "sequence_length (%d)" % (eval_offset, clip_length, sequence_length))

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    context_features = {
        "height": tf.FixedLenFeature([], dtype=tf.int64),
        "width": tf.FixedLenFeature([], dtype=tf.int64),
        "sequence_length": tf.FixedLenFeature([], dtype=tf.int64),
        "text": tf.FixedLenFeature([], dtype=tf.string),
        "label": tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "frames": tf.FixedLenSequenceFeature([], dtype=tf.string),
        "masks": tf.FixedLenSequenceFeature([], dtype=tf.string)
    }
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    # The "height", "width" and "sequence_length" context features are still
    # declared, so records must contain them. Their parsed values are tensors,
    # though, and cannot supply the Python ints that decode_frames and
    # set_shape need while the graph is built. The static arguments are used
    # instead.
    clip = decode_frames(sequence_parsed['frames'], height, width, sequence_length)

    # One-hot encode the label. The stored value is shifted down by one, so
    # labels are presumably 1-based (confirm against the record writer).
    label = tf.one_hot(context_parsed['label'] - 1, FLAGS.num_class)
    text = context_parsed['text']

    # Pick the start frame of a clip of FLAGS.seq_length frames.
    if is_training:
        # maxval is exclusive, so every valid start frame can be drawn.
        idx = tf.squeeze(tf.random_uniform(
            [1], 0, sequence_length - clip_length + 1, dtype=tf.int32))
    else:
        idx = eval_offset
    clip = clip[idx:idx + clip_length] / 255.0 * 2 - 1

    if is_training:
        # Randomly reverse the clip in time (50% chance).
        reverse = tf.squeeze(tf.random_uniform([1], 0, 2, dtype=tf.int32))
        clip = tf.cond(tf.equal(reverse, 0), lambda: clip, lambda: clip[::-1])

        # Randomly flip every frame horizontally (50% chance).
        flip = tf.squeeze(tf.random_uniform([1], 0, 2, dtype=tf.int32))
        clip = tf.cond(tf.equal(flip, 0), lambda: clip,
                       lambda: tf.map_fn(tf.image.flip_left_right, clip))

    clip.set_shape([clip_length, height, width, 3])

    return clip, label, text
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号