nnet_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def test_seq(self):
        """Check that an ``N.Sequence`` gives consistent results after pickling.

        A small conv -> dense stack is applied to a placeholder, along with
        ``f.T`` applied to its output.  The sequence is then round-tripped
        through ``cPickle`` and rebuilt the same way; shapes and output sums of
        the original and restored functions are compared.
        """
        X = K.placeholder((None, 28, 28, 1))
        f = N.Sequence([
            N.Conv(8, (3, 3), strides=1, pad='same'),
            N.Dimshuffle(pattern=(0, 3, 1, 2)),
            N.Flatten(outdim=2),
            N.Noise(level=0.3, noise_dims=None, noise_type='gaussian'),
            N.Dense(128, activation=tf.nn.relu),
            N.Dropout(level=0.3, noise_dims=None),
            N.Dense(10, activation=tf.nn.softmax)
        ])
        y = f(X)
        # NOTE(review): f.T is presumably the transposed network (the expected
        # output shape below equals the input shape) -- confirm against N.
        yT = f.T(y)
        f1 = K.function(X, y, defaults={K.is_training(): True})
        f2 = K.function(X, yT, defaults={K.is_training(): False})

        # Round-trip a locally built object; never unpickle untrusted data.
        f = cPickle.loads(cPickle.dumps(f))
        y = f(X)
        yT = f.T(y)
        f3 = K.function(X, y, defaults={K.is_training(): True})
        f4 = K.function(X, yT, defaults={K.is_training(): False})

        x = np.random.rand(12, 28, 28, 1)

        # Forward pass: original and restored sequence must agree.
        self.assertEqual(f1(x).shape, (2688, 10))
        self.assertEqual(f3(x).shape, (2688, 10))
        self.assertEqual(np.round(f1(x).sum(), 4),
                         np.round(f3(x).sum(), 4))
        # as_list() yields a list, which never compares equal to a tuple, so
        # normalise to a tuple before comparing with the expected shape.
        self.assertEqual(tuple(y.get_shape().as_list()), (None, 10))

        # Transposed pass: output shape matches the input placeholder.
        self.assertEqual(f2(x).shape, (12, 28, 28, 1))
        self.assertEqual(f4(x).shape, (12, 28, 28, 1))
        # Coarse comparison: only the first four characters of the printed
        # sums are required to match.
        self.assertEqual(str(f2(x).sum())[:4], str(f4(x).sum())[:4])
        self.assertEqual(tuple(yT.get_shape().as_list()), (None, 28, 28, 1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号