ADMMutils.py source code

python

Project: sparsecnn  Author: fkiaee
import tensorflow as tf

def block_truncate_conv(V, mu, rho):
    coef = 0.5
    V_shape = tf.shape(V)  # conv weights, assumed in TensorFlow conv2d layout [H, W, in, out]
    b = tf.sqrt(tf.divide(tf.multiply(2., mu), rho))  # threshold
    # Reshape the 4D weight tensor to a 2D matrix [in*out, H*W]: one row per (in, out) pair,
    # each row holding that H x W conv kernel in vectorized form.
    V_shape1 = tf.stack([V_shape[2] * V_shape[3], V_shape[0] * V_shape[1]])
    V = tf.reshape(tf.transpose(V, perm=[2, 3, 0, 1]), V_shape1)
    # frobenius_norm_block is defined elsewhere in ADMMutils.py (not shown here);
    # from its use, it returns one norm per row of V.
    norm_V = frobenius_norm_block(V, 1)
    norm_V_per_dimension = tf.divide(norm_V, tf.cast(V_shape1[1], tf.float32))
    # Implementation of Eq. 10 in the paper, branching inside the TensorFlow graph with tf.cond
    zero_part = tf.zeros_like(V)
    zero_ind = tf.greater_equal(b, norm_V_per_dimension)
    num_zero = tf.reduce_sum(tf.cast(zero_ind, tf.float32))
    # Parameters are passed to the tf.cond() branches through lambdas
    f4 = lambda: tf.greater_equal(tf.reduce_mean(norm_V), norm_V)
    f5 = lambda: zero_ind
    # If more than coef of the kernels fall under the threshold, zero only those at or below the mean norm
    zero_ind = tf.cond(tf.greater(num_zero, coef * tf.cast(V_shape1[0], tf.float32)), f4, f5)
    # TF 2.x tf.where broadcasts, so expand the per-row mask to [rows, 1] to select whole rows
    G = tf.where(tf.expand_dims(zero_ind, 1), zero_part, V)
    G_shape = tf.stack([V_shape[2], V_shape[3], V_shape[0], V_shape[1]])
    G = tf.transpose(tf.reshape(G, G_shape), perm=[2, 3, 0, 1])  # back to [H, W, in, out]
    return G, zero_ind
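To make the index bookkeeping and the fallback rule of block_truncate_conv easier to follow, here is a framework-free sketch that mirrors it step by step on nested Python lists. It is an illustration under stated assumptions, not the sparsecnn implementation: the name block_truncate_conv_ref is hypothetical, V is assumed to use TensorFlow's conv2d filter layout [H, W, in, out], and frobenius_norm_block(V, 1) is assumed to return the L2 norm of each row. As in the function above, a kernel is zeroed when b >= ||row|| / (H*W) with b = sqrt(2*mu/rho); if that would zero more than coef (half) of the kernels, only the kernels whose norm is at most the mean norm are zeroed instead.

```python
import math


def block_truncate_conv_ref(V, mu, rho, coef=0.5):
    """Hypothetical pure-Python mirror of block_truncate_conv.

    V is a nested list indexed V[h][w][i][o] (assumed layout [H, W, in, out]).
    Returns (G, zero_ind): G in the same layout, and zero_ind with one bool
    per (in, out) kernel, in row order i*O + o.
    """
    H, W, I, O = len(V), len(V[0]), len(V[0][0]), len(V[0][0][0])
    b = math.sqrt(2.0 * mu / rho)  # threshold
    # Same ordering as transpose(perm=[2, 3, 0, 1]) followed by a row-major
    # reshape: kernel (i, o) becomes row i*O + o, entry (h, w) column h*W + w.
    rows = [[V[h][w][i][o] for h in range(H) for w in range(W)]
            for i in range(I) for o in range(O)]
    # Assumption: frobenius_norm_block(V, 1) is the L2 norm of each row.
    norm_V = [math.sqrt(sum(x * x for x in row)) for row in rows]
    zero_ind = [b >= n / (H * W) for n in norm_V]
    # Fallback: if more than coef of the kernels would be zeroed,
    # zero only the kernels whose norm is at or below the mean norm.
    if sum(zero_ind) > coef * len(rows):
        mean_norm = sum(norm_V) / len(norm_V)
        zero_ind = [mean_norm >= n for n in norm_V]
    G_rows = [[0.0] * (H * W) if z else row for z, row in zip(zero_ind, rows)]
    # Undo the reshape/transpose: rebuild G[h][w][i][o].
    G = [[[[G_rows[i * O + o][h * W + w] for o in range(O)]
           for i in range(I)] for w in range(W)] for h in range(H)]
    return G, zero_ind


# One input channel, two output channels, 1x2 kernels: [3, 4] and [0.1, 0.1].
V = [[[[3.0, 0.1]], [[4.0, 0.1]]]]
G, zero_ind = block_truncate_conv_ref(V, mu=0.5, rho=1.0)  # b = 1
print(zero_ind)  # [False, True]
print(G)         # [[[[3.0, 0.0]], [[4.0, 0.0]]]]
```

With b = 1, the first kernel (norm 5, per-dimension norm 2.5) survives and the second (per-dimension norm about 0.07) is zeroed. Plain lists keep the index arithmetic explicit, which shows exactly where the transpose and reshape in the TensorFlow version place each weight.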