test_network_dense.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pruning_with_tensorflow 作者: ex4sperans 项目源码 文件源码
def test_shapes(self):

        input_size = 20
        n_classes = 5
        layer_sizes = [5, 10]

        network = network_dense.FullyConnectedClassifier(input_size=input_size,
                                                         n_classes=n_classes,
                                                         layer_sizes=layer_sizes,
                                                         model_path='temp',
                                                         verbose=False)

        self.assertEqual(network.logits.get_shape().as_list(), [None, 5])
        self.assertEqual(network.loss.get_shape().as_list(), [])
        self.assertIsInstance(network.train_op, tf.Operation)

        shapes = [[20, 5], [5, 10], [10, 5]]
        for v, shape in zip(network.weight_matrices, shapes):
            self.assertEqual(v.get_shape().as_list(), shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号