def minibatch_discrimination(input_layer, num_kernels, dim_per_kernel=5, name='minibatch_discrim'):
    """Compute minibatch-discrimination features for a batch of samples.

    Each sample is projected to ``num_kernels`` vectors of length
    ``dim_per_kernel``. For every sample ``i`` and kernel ``k`` the output is
    ``sum_j exp(-||M[i, k] - M[j, k]||_1) + b[k]``. Here ``j`` ranges over the
    whole batch, including ``i`` itself, which always contributes
    ``exp(0) = 1``.

    Args:
        input_layer: 2-D tensor of shape (batch, num_features). The feature
            dimension must be statically known because it sizes ``W``. The
            batch dimension may be dynamic (``None``).
        num_kernels: Number of projection kernels; the width of the output.
        dim_per_kernel: Length of each kernel's projected vector.
        name: Name scope for the ops created here. ``tf.get_variable``
            ignores name scopes, so the variables are still created as ``W``
            and ``b`` in the caller's current variable scope. Call this inside
            a distinct ``tf.variable_scope`` to avoid clashing with other
            ``W``/``b`` variables.

    Returns:
        Tensor of shape (batch, num_kernels) holding the summed similarities
        plus a per-kernel bias.

    Raises:
        ValueError: If the feature dimension of ``input_layer`` is not
            statically known.
    """
    num_features = input_layer.get_shape().as_list()[1]
    if num_features is None:
        raise ValueError(
            'minibatch_discrimination needs a statically known feature dimension')

    with tf.name_scope(name):
        W = tf.get_variable('W', [num_features, num_kernels * dim_per_kernel],
                            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable('b', [num_kernels],
                            initializer=tf.constant_initializer(0.0))

        # (batch, num_kernels, dim_per_kernel). -1 keeps the batch dimension
        # dynamic; the static batch size is None for [None, F] inputs, which
        # made the previous reshape fail.
        activation = tf.reshape(tf.matmul(input_layer, W),
                                [-1, num_kernels, dim_per_kernel])

        # (batch, K, D, 1) - (1, K, D, batch) broadcasts to (batch, K, D, batch),
        # comparing every sample with every sample, per kernel.
        lhs = tf.expand_dims(activation, 3)
        rhs = tf.expand_dims(tf.transpose(activation, perm=[1, 2, 0]), 0)

        # L1 distance over the kernel-vector axis -> (batch, K, batch).
        abs_diff = tf.reduce_sum(tf.abs(lhs - rhs), axis=2)
        # Sum the similarities over the second batch axis -> (batch, K).
        features = tf.reduce_sum(tf.exp(-abs_diff), axis=2)
        return features + b
ops.py 文件源码
python
阅读 33
收藏 0
点赞 0
评论 0
评论列表
文章目录