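import tensorflow as tf

def _apply_prune_on_grads(self,
                          grads_and_vars: list,
                          threshold: float):
    # Zero out the gradients of pruned weights (|w| < threshold) so that
    # the optimizer update does not move them away from their pruned value.
    grads_and_vars_sparse = []
    for grad, var in grads_and_vars:
        # Only variables named 'weights' are pruned; biases and variables
        # without a gradient (grad is None) pass through unchanged.
        if grad is not None and 'weights' in var.name:
            small_weights = tf.greater(threshold, tf.abs(var))
            # 1 keeps a gradient, 0 blocks it; cast to the gradient's dtype.
            mask = tf.cast(tf.logical_not(small_weights), grad.dtype)
            grad = grad * mask
        grads_and_vars_sparse.append((grad, var))
    return grads_and_vars_sparse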
Source file: network_dense.py (Python)
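To make the masking rule concrete without a TensorFlow dependency, here is a minimal pure-Python sketch. It assumes plain lists stand in for tensors and a plain SGD step stands in for whatever optimizer the original network uses; `prune_mask` and `masked_sgd_step` are hypothetical names, not part of network_dense.py. It applies the same rule as `_apply_prune_on_grads`: a weight whose magnitude is below `threshold` gets a zero gradient, so the SGD step leaves it unchanged.

```python
from typing import List, Tuple


def prune_mask(weights: List[float], threshold: float) -> List[float]:
    # Hypothetical helper: 0.0 marks a pruned weight (threshold > |w|),
    # 1.0 keeps it trainable. Mirrors logical_not(greater(threshold, abs(var))).
    return [0.0 if threshold > abs(w) else 1.0 for w in weights]


def masked_sgd_step(weights: List[float],
                    grads: List[float],
                    threshold: float,
                    lr: float = 0.1) -> Tuple[List[float], List[float]]:
    # Hypothetical helper: mask the gradients first, then apply plain SGD
    # (an assumption; the original code hands grads_and_vars to its optimizer).
    mask = prune_mask(weights, threshold)
    masked_grads = [g * m for g, m in zip(grads, mask)]
    new_weights = [w - lr * g for w, g in zip(weights, masked_grads)]
    return new_weights, masked_grads


if __name__ == "__main__":
    weights = [0.0, 0.5, -0.02, -1.0]
    grads = [0.3, 0.2, -0.4, 0.1]
    new_weights, masked = masked_sgd_step(weights, grads, threshold=0.05)
    print(new_weights)
    print(masked)
```

With `threshold=0.05`, the weights `0.0` and `-0.02` are masked and keep their values, while `0.5` and `-1.0` move by `lr * grad`. One design caveat: with momentum-based optimizers, a zero gradient does not by itself freeze a weight if earlier steps left non-zero optimizer state for it.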