ops.py source code


Project: Unsupervised-Anomaly-Detection-with-Generative-Adversarial-Networks  Author: xtarx
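```python
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.contrib)


def minibatch_discrimination(input_layer, num_kernels, dim_per_kernel=5, name='minibatch_discrim'):
    """Minibatch discrimination features for a GAN discriminator.

    input_layer: [batch_size, num_features] tensor (num_features must be static).
    Returns a [batch_size, num_kernels] tensor.
    """
    num_features = input_layer.get_shape().as_list()[1]
    # Scope the variables so the layer can be used more than once in a graph.
    with tf.variable_scope(name):
        W = tf.get_variable('W', [num_features, num_kernels * dim_per_kernel],
                            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable('b', [num_kernels], initializer=tf.constant_initializer(0.0))
    # Project each sample to a [num_kernels, dim_per_kernel] matrix; -1 allows a dynamic batch size.
    activation = tf.matmul(input_layer, W)
    activation = tf.reshape(activation, [-1, num_kernels, dim_per_kernel])
    # Broadcast [B, K, D, 1] against [1, K, D, B] to form all pairwise differences.
    tmp1 = tf.expand_dims(activation, 3)
    tmp2 = tf.expand_dims(tf.transpose(activation, perm=[1, 2, 0]), 0)
    # L1 distance per kernel between every pair of samples: [B, K, B].
    abs_diff = tf.reduce_sum(tf.abs(tmp1 - tmp2), axis=2)
    # Sum of exp(-distance) over the batch (includes the sample itself): [B, K].
    f = tf.reduce_sum(tf.exp(-abs_diff), axis=2)
    return f + b
```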
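The function implements minibatch discrimination, the technique introduced in Salimans et al. (2016), "Improved Techniques for Training GANs". Each sample x_i is projected to a matrix M_i. For every kernel k, the function then computes the closeness of sample i to every sample j in the batch, sum_j exp(-||M_{i,k} - M_{j,k}||_1), and adds a bias. The following NumPy version is a sketch of the same forward pass, written so the shapes and values can be checked without building a TensorFlow graph. The function name `minibatch_discrimination_np` and the random inputs are hypothetical illustrations and are not part of the original project.

```python
import numpy as np


def minibatch_discrimination_np(x, W, b, num_kernels, dim_per_kernel):
    """Hypothetical NumPy reference for the TensorFlow layer's forward pass.

    x: [batch, num_features], W: [num_features, num_kernels * dim_per_kernel],
    b: [num_kernels]. Returns [batch, num_kernels].
    """
    # Project each sample to a [num_kernels, dim_per_kernel] matrix M_i.
    m = (x @ W).reshape(-1, num_kernels, dim_per_kernel)
    # Pairwise L1 distances per kernel: [B, 1, K, D] - [1, B, K, D] -> sum over D -> [B, B, K].
    l1 = np.abs(m[:, None, :, :] - m[None, :, :, :]).sum(axis=-1)
    # Sum exp(-distance) over j; the j == i term contributes exp(0) = 1.
    return np.exp(-l1).sum(axis=1) + b


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2 * 5))
b = np.zeros(2)
out = minibatch_discrimination_np(x, W, b, num_kernels=2, dim_per_kernel=5)
print(out.shape)  # (4, 2)
```

With a zero bias, every entry lies between 1 (the sample is far from all others) and the batch size (all samples are identical). This range is what lets the discriminator detect a generator that collapses to a single mode. Note that the original layer returns only these features. A discriminator would typically concatenate them with `input_layer` before the next layer, but that step is not shown in the source.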