def _flat_item_size(tensor, label):
    """Return the element count of one item of ``tensor`` (all dims after the first).

    The flattening reshape below needs this size statically, so it must be
    known at graph-construction time.

    Raises:
        ValueError: if ``tensor`` has rank < 2 or any non-leading dimension is
            unknown (None).
    """
    dims = tensor.get_shape().as_list()
    if len(dims) < 2 or any(d is None for d in dims[1:]):
        raise ValueError(
            "%s must have rank >= 2 with static non-leading dims, got shape %s"
            % (label, dims))
    size = 1
    for d in dims[1:]:
        size *= d
    return size


def match_to_dict_conv(image_as_patches, dictionary, include_counts=False):
    """Replace every patch with its nearest dictionary entry (squared L2 distance).

    Args:
        image_as_patches: tensor whose first axis indexes patches; all remaining
            axes are flattened into one feature vector per patch.
        dictionary: tensor whose first axis indexes dictionary entries; all
            remaining axes are flattened and must have the same total size as
            one patch.
        include_counts: if True, also report which dictionary entries were
            chosen and how many times each was chosen.

    Returns:
        ``tf.gather(dictionary, min_loc)``: for each patch, the closest
        dictionary entry (same per-entry shape as ``dictionary``). If
        ``include_counts`` is True, returns ``(matched, [unique_idx, counts])``
        where ``unique_idx``/``counts`` come from ``tf.unique_with_counts`` on
        the per-patch nearest-entry indices.

    Raises:
        ValueError: if either input lacks static non-leading dims, or if a
            patch and a dictionary entry flatten to different sizes.
    """
    dict_size = _flat_item_size(dictionary, "dictionary")
    patch_size = _flat_item_size(image_as_patches, "image_as_patches")
    if dict_size != patch_size:
        raise ValueError(
            "patch size %d does not match dictionary entry size %d"
            % (patch_size, dict_size))

    dict_flat = tf.reshape(dictionary, [-1, dict_size])            # (entries, D)
    patches_flat = tf.reshape(image_as_patches, [-1, patch_size])  # (patches, D)

    # ||x - d||^2 = ||x||^2 - 2 x.d + ||d||^2. The ||x||^2 term is identical for
    # every dictionary entry within a row, so it cannot change the argmin and
    # is deliberately omitted.
    cross_term = -2 * tf.matmul(patches_flat, dict_flat, transpose_b=True)
    dict_sq_norms = tf.reduce_sum(tf.square(dict_flat), axis=1)  # (entries,)
    distance = dict_sq_norms + cross_term  # broadcasts across patch rows
    min_loc = tf.argmin(distance, axis=1)  # nearest entry index per patch

    matched = tf.gather(dictionary, min_loc)
    if include_counts:
        unique_idx, _, counts = tf.unique_with_counts(min_loc)
        return matched, [unique_idx, counts]
    return matched
评论列表
文章目录