def add_minibatch_features(image, df_dim, n_kernels=300, dim_per_kernel=50):
    """Append minibatch-discrimination features to ``image``.

    Pipeline:
    - ``image`` goes through conv -> leaky ReLU -> conv (project helpers
      ``conv2d``/``lrelu``).
    - The result is flattened and projected with ``linear`` to
      ``n_kernels * dim_per_kernel`` values per sample, then viewed as
      ``(batch, n_kernels, dim_per_kernel)``.
    - For every kernel, each sample is compared to every sample in the batch
      via the L1 distance of their kernel vectors. The similarity
      ``exp(-distance)`` is taken, with self-similarity zeroed out, and
      averaged over the batch.
    - The ``(batch, n_kernels)`` result is reshaped to
      ``(batch, 1, 1, n_kernels)`` and handed to
      ``conv_cond_concat(image, ...)``.

    Based on the minibatch-discrimination layer in OpenAI improved-gan
    (``imagenet/discriminator.py``).

    Normalization note: the similarities are averaged with
    ``tf.reduce_mean`` over the full batch axis. The zeroed self-pair
    therefore counts in the denominator, giving
    ``sum_{j != i} / batch_size``. The improved-gan reference divides by
    ``reduce_sum(mask) == batch_size * (batch_size - 1)`` instead. These
    features are thus ``batch_size - 1`` times larger than the reference's.
    Kept as-is so existing trained models behave identically.

    Args:
        image: Input tensor. Presumably 4-D ``(batch, height, width,
            channels)`` as the ``conv2d``/``conv_cond_concat`` helpers
            expect -- TODO confirm.
        df_dim: Passed as the second argument of both ``conv2d`` calls
            (presumably the number of output filters -- confirm against
            the helper).
        n_kernels: Number of minibatch kernels, i.e. how many feature
            channels are appended. Defaults to the previously hard-coded 300.
        dim_per_kernel: Length of each kernel's projection vector.
            Defaults to the previously hard-coded 50.

    Returns:
        The result of ``conv_cond_concat(image, mb_features)`` where
        ``mb_features`` has shape ``(batch, 1, 1, n_kernels)``.
    """
    # Feature extractor: conv -> leaky ReLU -> conv. Scope names are kept
    # unchanged so existing checkpoints still load.
    h_mb0 = lrelu(conv2d(image, df_dim, name='d_mb0_conv'))
    h_mb1 = conv2d(h_mb0, df_dim, name='d_mbh1_conv')

    # Flatten all non-batch axes. np.prod needs the non-batch static dims of
    # h_mb1 to be known at graph-construction time.
    conv_dims = int(np.prod(h_mb1.get_shape().as_list()[1:]))
    flat = tf.reshape(h_mb1, [-1, conv_dims])

    # Project and view as (batch, n_kernels, dim_per_kernel).
    x = linear(flat, n_kernels * dim_per_kernel, 'd_mbLinear')
    act = tf.reshape(x, (-1, n_kernels, dim_per_kernel))

    # Pairwise L1 distances per kernel:
    #   (batch, n_k, d_k, 1) - (1, n_k, d_k, batch) -> (batch, n_k, d_k, batch)
    #   summed over d_k -> abs_dif[i, k, j] with shape (batch, n_k, batch).
    # NOTE: the intermediate holds batch**2 * n_kernels * dim_per_kernel values.
    act_tp = tf.transpose(act, [1, 2, 0])
    abs_dif = tf.reduce_sum(
        tf.abs(tf.expand_dims(act, 3) - tf.expand_dims(act_tp, 0)), 2)

    # On the diagonal (i == j) the distance is 0, so exp(-0) == 1. Subtracting
    # the identity zeroes each sample's similarity to itself. The identity is
    # built in abs_dif's dtype so non-float32 activations don't clash.
    batch = tf.shape(abs_dif)[0]
    eye = tf.expand_dims(tf.eye(batch, dtype=abs_dif.dtype), 1)  # (batch, 1, batch)
    masked = tf.exp(-abs_dif) - eye

    # Average over the batch axis (see the normalization note in the docstring).
    f1 = tf.reduce_mean(masked, 2)  # (batch, n_kernels)
    mb_features = tf.reshape(f1, [-1, 1, 1, n_kernels])
    return conv_cond_concat(image, mb_features)
## The commented-out reference implementation below is from https://github.com/openai/improved-gan/blob/master/imagenet/discriminator.py#L88
#def add_minibatch_features(image,df_dim,batch_size):
# shape = image.get_shape().as_list()
# dim = np.prod(shape[1:]) # dim = prod(9,2) = 18
# h_mb0 = lrelu(conv2d(image, df_dim, name='d_mb0_conv'))
# h_mb1 = conv2d(h_mb0, df_dim, name='d_mbh1_conv')
#
# dims=h_mb1.get_shape().as_list()
# conv_dims=np.prod(dims[1:])
#
# image_ = tf.reshape(h_mb1, tf.stack([-1, conv_dims]))
# #image_ = tf.reshape(h_mb1, tf.stack([batch_size, -1]))
#
# n_kernels = 300
# dim_per_kernel = 50
# x = linear(image_, n_kernels * dim_per_kernel,'d_mbLinear')
# activation = tf.reshape(x, (batch_size, n_kernels, dim_per_kernel))
# big = np.zeros((batch_size, batch_size), dtype='float32')
# big += np.eye(batch_size)
# big = tf.expand_dims(big, 1)
# abs_dif = tf.reduce_sum(tf.abs(tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)), 2)
# mask = 1. - big
# masked = tf.exp(-abs_dif) * mask
# f1 = tf.reduce_sum(masked, 2) / tf.reduce_sum(mask)
# mb_features = tf.reshape(f1, [batch_size, 1, 1, n_kernels])
# return conv_cond_concat(image, mb_features)
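# Stray web-page navigation text from the source post (not code):
# "评论列表" (comment list), "文章目录" (table of contents)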