ADMMutils.py source code

python

Project: sparsecnn  Author: fkiaee
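import tensorflow as tf  # pre-1.0 TensorFlow API: tf.mul, tf.sub, tf.select, tf.concat(axis, values)

# NOTE: frobenius_norm_block is not shown here; it is presumably defined elsewhere in ADMMutils.py.
def block_shrinkage_conv(V,mu,rho):
    # Block shrinkage of a 4-D tensor V; each slice V[:, :, i, j] is treated as one block.
    coef = 0.5  # fallback trigger: fraction of blocks the b-threshold may zero
    V_shape = tf.shape(V); one_val = tf.constant(1.0)
    b = tf.div(mu,rho)  # shrinkage threshold mu/rho
    # 2-D shape: one row per block (dim 2 * dim 3), one column per entry of the block (dim 0 * dim 1).
    V_shape1 = tf.concat(0,[tf.mul(tf.slice(V_shape,[2],[1]),tf.slice(V_shape,[3],[1])),tf.mul(tf.slice(V_shape,[0],[1]),tf.slice(V_shape,[1],[1]))])
    V = tf.reshape(tf.transpose(V,perm=[2,3,0,1]),V_shape1)
    norm_V = frobenius_norm_block(V,1)  # presumably one norm per row (block)
    norm_V_per_dimension = tf.div(norm_V,tf.cast(tf.slice(V_shape1,[1],[1]),'float'))
    zero_part = tf.zeros(V_shape1)
    # Default rule: zero a block when its norm divided by the number of entries per block is at most b.
    zero_ind = tf.greater_equal(b,norm_V_per_dimension)
    num_zero = tf.reduce_sum(tf.cast(zero_ind,'float'))
#    f4 = lambda: tf.greater_equal(tf.truediv(tf.add(tf.reduce_min(fro),tf.reduce_mean(fro)),2.0),fro)
    # Fallback rule: zero the blocks whose norm is at most the mean norm.
    f4 = lambda: tf.greater_equal(tf.reduce_mean(norm_V),norm_V)
    f5 = lambda: zero_ind
    # Use the fallback when the default rule would zero more than a fraction coef of all blocks.
    zero_ind = tf.cond(tf.greater(num_zero,tf.mul(coef,tf.cast(V_shape1[0],'float'))),f4,f5)
    # Zeroed blocks become 0; the remaining blocks are scaled by (1 - b/norm).
    G = tf.select(zero_ind,zero_part,tf.mul(tf.sub(one_val,tf.div(b,tf.reshape(norm_V,[-1,1]))),V))
    # Restore the original 4-D layout.
    G_shape = tf.concat(0,[tf.slice(V_shape,[2],[1]),tf.slice(V_shape,[3],[1]),tf.slice(V_shape,[0],[1]),tf.slice(V_shape,[1],[1])])
    G = tf.transpose(tf.reshape(G,G_shape),perm=[2,3,0,1])
    return G,zero_ind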
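
For readers without a pre-1.0 TensorFlow install, the reshape-then-shrink pattern used by the function can be sketched in NumPy. This is not the author's exact rule: it applies the textbook block soft-thresholding operator (zero a block when its L2 norm is at most b, otherwise scale it by 1 - b/norm) and leaves out the per-dimension threshold and the mean-norm fallback. The names `block_shrinkage_rows` and `block_shrinkage_conv_np` are hypothetical, and reading the four axes as (height, width, in-channels, out-channels) is an assumption based on the usual TensorFlow kernel layout.

```python
import numpy as np

def block_shrinkage_rows(V2d, b):
    """Textbook block soft-thresholding applied to each row of a 2-D array."""
    norms = np.linalg.norm(V2d, axis=1)           # L2 norm of every row (block)
    zero_rows = norms <= b                         # blocks that shrink to exactly zero
    # Surviving rows are scaled by (1 - b/norm); the maximum() guards against 0/0.
    scale = np.where(zero_rows, 0.0, 1.0 - b / np.maximum(norms, 1e-12))
    return V2d * scale[:, None], zero_rows

def block_shrinkage_conv_np(V, b):
    """Apply the row-wise operator to each V[:, :, i, j] slice of a 4-D array."""
    h, w, cin, cout = V.shape                      # assumed kernel layout
    # Same flattening as the TensorFlow code: one row per (i, j) pair.
    V2d = V.transpose(2, 3, 0, 1).reshape(cin * cout, h * w)
    G2d, zero_rows = block_shrinkage_rows(V2d, b)
    # Undo the flattening to recover the original axis order.
    G = G2d.reshape(cin, cout, h, w).transpose(2, 3, 0, 1)
    return G, zero_rows

# Usage: two 2x2 blocks, one with norm 2.0 and one with norm 0.2, threshold b = 1.0.
V = np.ones((2, 2, 1, 2))
V[..., 1] *= 0.1
G, zero_rows = block_shrinkage_conv_np(V, 1.0)
```

With b = 1.0, the first block is halved (1 - 1/2) and the second is zeroed, so `zero_rows` marks only the second block.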