LookupConvolution2d.py source code

python

Project: tf-lcnn  Author: ildoonet
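import tensorflow as tf  # the snippet uses session-style (TF 1.x) execution


def extract_dense_weights(sess):
    """Evaluate every dense layer's kernel and bias and store them in dense_weights.

    dense_layers and dense_weights are module-level names defined outside this snippet.
    """
    for key in dense_layers.keys():
        layer = dense_layers[key]

        # Build a sparse view of the dense kernel from its nonzero entries
        dense_kernel = layer.kernel
        dense_kernel_shape = dense_kernel.get_shape().as_list()
        # Disabled alternative: flatten the kernel to 2-D and transpose it first
        # dense_kernel = tf.reshape(dense_kernel, [dense_kernel_shape[0] * dense_kernel_shape[1] * dense_kernel_shape[2],
        #                                          dense_kernel_shape[3]])
        # dense_kernel = tf.transpose(dense_kernel)
        idx = tf.where(tf.not_equal(dense_kernel, 0))  # coordinates of nonzero weights
        sparse_kernel = tf.SparseTensor(idx, tf.gather_nd(dense_kernel, idx), dense_kernel.get_shape())

        # Evaluate the tensors; a layer without a bias gets b = None
        if layer.bias is not None:
            dk, k, b = sess.run([dense_kernel, sparse_kernel, layer.bias])
        else:
            dk, k = sess.run([dense_kernel, sparse_kernel])
            b = None

        # Store the results under '<layer key>/<name>'
        dense_weights['%s/%s' % (key, 'kernel_dense')] = dk
        dense_weights['%s/%s' % (key, 'kernel')] = k
        dense_weights['%s/%s' % (key, 'kernel_shape')] = dense_kernel_shape
        dense_weights['%s/%s' % (key, 'bias')] = b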
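
The dense-to-sparse step in the function above (find the nonzero coordinates, gather the values at those coordinates, keep the original shape) can be sketched without a TensorFlow session. The following is a minimal NumPy illustration and not code from tf-lcnn: the helper names `dense_to_sparse_parts` and `sparse_parts_to_dense` and the toy kernel are assumptions made for this example.

```python
import numpy as np


def dense_to_sparse_parts(dense):
    """Return (indices, values, shape) for the nonzero entries of `dense`.

    Hypothetical helper: np.argwhere stands in for tf.where, and fancy
    indexing stands in for tf.gather_nd.
    """
    idx = np.argwhere(dense != 0)   # shape (nnz, ndim), row-major order
    values = dense[tuple(idx.T)]    # gather the values at those coordinates
    return idx, values, list(dense.shape)


def sparse_parts_to_dense(idx, values, shape):
    """Rebuild the dense array from its (indices, values, shape) triple."""
    out = np.zeros(shape, dtype=values.dtype)
    out[tuple(idx.T)] = values
    return out


# Toy 4-D kernel (height, width, in_channels, out_channels) with two nonzero weights
kernel = np.zeros((3, 3, 2, 4), dtype=np.float32)
kernel[0, 0, 0, 1] = 0.5
kernel[2, 1, 1, 3] = -1.25

idx, values, shape = dense_to_sparse_parts(kernel)
restored = sparse_parts_to_dense(idx, values, shape)
```

Storing only the coordinates and values pays off when most kernel weights are zero; the triple holds the same information as the `indices`, `values`, and `dense_shape` fields of a `tf.SparseTensor`.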