def block_truncate_conv(V, mu, rho, coef=0.5):
    """Block hard-thresholding of a 4D convolution weight tensor.

    ``V`` is transposed/reshaped into a 2D matrix with ``V.shape[2] *
    V.shape[3]`` rows of ``V.shape[0] * V.shape[1]`` elements, so each row is
    one vectorized filter slice ("block").  A block is zeroed when the
    threshold ``b = sqrt(2 * mu / rho)`` is >= its norm divided by the block's
    element count; otherwise it is kept unchanged.  The original comment calls
    this "Eq.10 in the paper" (the paper is not identified in this file).

    If that rule would zero more than ``coef`` times the number of blocks, it
    is replaced by a fallback that zeros only the blocks whose (un-normalised)
    norm is <= the mean block norm.

    Args:
        V: 4D weight tensor.  Axes 0-1 are flattened into each block and axes
            2-3 index the blocks -- presumably TensorFlow's
            ``[height, width, in, out]`` conv-filter layout; confirm.
        mu: scalar numerator term of the threshold ``sqrt(2 * mu / rho)``.
        rho: scalar denominator term of the threshold.
        coef: fraction of blocks above which the mean-norm fallback is used.
            Defaults to 0.5 (the value that used to be hard-coded).

    Returns:
        ``(G, zero_ind)``: ``G`` has the same shape and layout as ``V`` with
        the selected blocks set to zero; ``zero_ind`` is the boolean mask
        (True = zeroed) computed over the flattened blocks.

    NOTE(review): the ``'float'`` casts below are float32, so this assumes
    float32 weights.
    """
    V_shape = tf.shape(V)
    # Threshold b = sqrt(2 * mu / rho).
    b = tf.sqrt(tf.div(tf.mul(2., mu), rho))

    # Move axes 2-3 to the front and flatten, giving a 2D matrix whose rows
    # contain the conv filters in vectorized form.
    V_t = tf.transpose(V, perm=[2, 3, 0, 1])
    # Shape of the transposed 4D tensor; used later to undo the flattening.
    G_shape = tf.shape(V_t)
    V_shape1 = tf.concat(0, [
        tf.mul(tf.slice(V_shape, [2], [1]), tf.slice(V_shape, [3], [1])),
        tf.mul(tf.slice(V_shape, [0], [1]), tf.slice(V_shape, [1], [1])),
    ])
    V = tf.reshape(V_t, V_shape1)

    # frobenius_norm_block is defined elsewhere; presumably it returns one
    # norm per row (reduction over axis 1) -- verify against its definition.
    norm_V = frobenius_norm_block(V, 1)
    # Divide each block norm by the number of elements in a block.
    norm_V_per_dimension = tf.div(
        norm_V, tf.cast(tf.slice(V_shape1, [1], [1]), 'float'))

    # Eq. 10, with the "too many zeros" decision taken inside the graph via
    # tf.cond.  zeros_like matches the reshaped V's shape and dtype exactly.
    zero_part = tf.zeros_like(V)
    threshold_ind = tf.greater_equal(b, norm_V_per_dimension)
    num_zero = tf.reduce_sum(tf.cast(threshold_ind, 'float'))
    num_blocks = tf.cast(V_shape1[0], 'float')

    def mean_fallback():
        # NOTE(review): compares the raw norm_V (not norm_V_per_dimension)
        # with its mean, as the original did -- confirm this is intended.
        return tf.greater_equal(tf.reduce_mean(norm_V), norm_V)

    zero_ind = tf.cond(tf.greater(num_zero, tf.mul(coef, num_blocks)),
                       mean_fallback, lambda: threshold_ind)

    # Replace the masked blocks with zeros (assumes zero_ind selects rows of
    # V, i.e. it has one entry per block).
    G = tf.select(zero_ind, zero_part, V)
    # Restore the 4D transposed shape, then undo the transpose: perm
    # [2, 3, 0, 1] is its own inverse, so this returns V's original layout.
    G = tf.transpose(tf.reshape(G, G_shape), perm=[2, 3, 0, 1])
    return G, zero_ind
评论列表
文章目录