dict_matcher.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:AuthoringDecompositions 作者: jrock08 项目源码 文件源码
def match_to_dict_conv(image_as_patches, dictionary, include_counts=False):
    """Replace every image patch with its nearest dictionary entry.

    "Nearest" means smallest squared Euclidean distance between a flattened
    patch and a flattened dictionary entry. Despite the function name, the
    patch/entry inner products are computed with a single matrix multiply;
    an earlier conv2d-based formulation was replaced by this one.

    Args:
        image_as_patches: rank-4 tensor of patches, one patch per index of
            axis 0. Every axis after the first is flattened, so the per-patch
            element count must equal the dictionary's per-entry element count.
        dictionary: rank-4 tensor of dictionary entries, one entry per index
            of axis 0.
        include_counts: if True, also report which dictionary indices were
            selected and how often each one was selected.

    Returns:
        ``tf.gather(dictionary, min_loc)``: for every patch, its closest
        dictionary entry (shape ``[num_patches] + dictionary.shape[1:]``).
        When ``include_counts`` is True, returns the tuple
        ``(matched, [unique_indices, counts])`` built from
        ``tf.unique_with_counts`` over the selected indices.
    """
    # Single-argument print(...) prints identically under Python 2 and 3.
    print('match_to_dict_conv')

    # Distinct names for the dictionary and patch shapes; the leading (count)
    # dimension is not needed, only the per-element size.
    _, dict_w, dict_h, dict_c = dictionary.get_shape().as_list()
    # Transpose so each COLUMN is one flattened dictionary entry.
    dict_as_filt = tf.transpose(
        tf.reshape(dictionary, [-1, dict_w * dict_h * dict_c]))
    print(dict_as_filt.get_shape())

    _, patch_w, patch_h, patch_c = image_as_patches.get_shape().as_list()
    # One flattened patch per ROW.
    image_flattened = tf.reshape(
        image_as_patches, [-1, patch_w * patch_h * patch_c])
    print(image_flattened.get_shape())

    # -2 * <patch, entry> for every (patch, entry) pair:
    # rows index patches, columns index dictionary entries.
    pair_dist = -2 * tf.matmul(image_flattened, dict_as_filt)
    print(pair_dist.get_shape())

    # ||entry||^2, broadcast across all rows. The ||patch||^2 term is omitted
    # on purpose: it is constant along each row, so it cannot change argmin.
    single_dist = tf.reduce_sum(tf.square(dictionary), [1, 2, 3])
    distance = single_dist + pair_dist
    print(distance.get_shape())

    # Index of the closest dictionary entry for each patch.
    min_loc = tf.argmin(distance, 1)
    print(min_loc.get_shape())

    if include_counts:
        # The second output of unique_with_counts (per-element index) is unused.
        unique_idx, _, count = tf.unique_with_counts(min_loc)
        return tf.gather(dictionary, min_loc), [unique_idx, count]
    return tf.gather(dictionary, min_loc)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号