def extract_dense_weights(sess):
    """Evaluate every layer in ``dense_layers`` and record its weights in ``dense_weights``.

    For each ``(key, layer)`` pair in the module-level ``dense_layers`` mapping
    this builds a sparse view of ``layer.kernel`` containing only its non-zero
    entries. It then evaluates, in one ``sess.run`` call per layer, the dense
    kernel, the sparse kernel and (when the layer has one) the bias. The results
    are stored in the module-level ``dense_weights`` dict under these keys:

    * ``'<key>/kernel_dense'`` -- the evaluated dense kernel.
    * ``'<key>/kernel'``       -- the evaluated sparse kernel (presumably a
      ``tf.SparseTensorValue`` in TF1 graph mode -- confirm).
    * ``'<key>/kernel_shape'`` -- the kernel's static shape as a Python list.
    * ``'<key>/bias'``         -- the evaluated bias, or ``None`` when
      ``layer.bias`` is ``None``.

    Args:
        sess: Session used to evaluate the tensors. It must provide a
            ``run(fetches)`` method compatible with ``tf.Session.run``.

    Returns:
        None. Results are written into ``dense_weights`` as a side effect.
    """
    for key, layer in dense_layers.items():
        dense_kernel = layer.kernel
        dense_kernel_shape = dense_kernel.get_shape().as_list()

        # Sparse view of the kernel: the coordinates of every non-zero entry,
        # the values found at those coordinates, and the kernel's own static
        # shape as the dense shape.
        # NOTE(review): these ops are created on every call, so calling this
        # function repeatedly keeps adding nodes to the default graph.
        idx = tf.where(tf.not_equal(dense_kernel, 0))
        sparse_kernel = tf.SparseTensor(idx, tf.gather_nd(dense_kernel, idx),
                                        dense_kernel.get_shape())

        # Only fetch the bias when the layer actually has one; otherwise the
        # recorded bias is None.
        if layer.bias is not None:
            dk, k, b = sess.run([dense_kernel, sparse_kernel, layer.bias])
        else:
            dk, k = sess.run([dense_kernel, sparse_kernel])
            b = None

        dense_weights['%s/%s' % (key, 'kernel_dense')] = dk
        dense_weights['%s/%s' % (key, 'kernel')] = k
        dense_weights['%s/%s' % (key, 'kernel_shape')] = dense_kernel_shape
        dense_weights['%s/%s' % (key, 'bias')] = b
# 评论列表 (Comments)
# 文章目录 (Table of contents)