test_seq2seq.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:deep_learning 作者: wecliqued 项目源码 文件源码
def test_create_cell(self):
        """Check the cell and initial state returned by ``_create_cell``.

        The input batch is one-hot encoded with one extra slot because id 0
        is reserved for padding, so the one-hot depth is ``len(vocab) + 1``.

        The assertions are:
        - the returned cell is a ``tf.contrib.rnn.MultiRNNCell``;
        - there is one initial state per stacked cell;
        - each initial state has shape ``[None, hidden_size]``;
        - no trainable variables exist yet, because the cell has not been
          called.
        """
        model = self.seq2seq

        # Id 0 is reserved for padding, which widens the one-hot encoding by one.
        one_hot_depth = len(model.vocab) + 1
        n_layers = self.no_stacked_cells
        state_size = self.hidden_size

        token_ids = tf.placeholder(dtype=tf.int32, shape=[None, None])
        encoded = tf.one_hot(token_ids, depth=one_hot_depth)
        self.assertHasShape(encoded, [None, None, one_hot_depth])

        # Build the cell from the one-hot encoded batch (not the raw ids).
        cell, initial_states = model._create_cell(encoded, n_layers)
        self.assertIsInstance(cell, tf.contrib.rnn.MultiRNNCell)
        self.assertEqual(len(initial_states), n_layers)
        for layer_state in initial_states:
            self.assertHasShape(layer_state, [None, state_size])

        # The cell's internal variables are only created once __call__ runs,
        # so at this point nothing trainable should be registered.
        self.assertListEqual(tf.trainable_variables(), [])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号