def test_create_cell(self):
    seq2seq = self.seq2seq
    # We use a one-hot encoding of the input batch; this is how it is constructed.
    # Id 0 is reserved for padding, so the vocabulary size increases by one.
    vocab_len = len(seq2seq.vocab)
    depth = vocab_len + 1
    no_stacked_cells = self.no_stacked_cells
    hidden_size = self.hidden_size
    seq = tf.placeholder(dtype=tf.int32, shape=[None, None])
    one_hot_seq = tf.one_hot(seq, depth=depth)
    self.assertHasShape(one_hot_seq, [None, None, depth])
    # Create the cell, using one_hot_seq as the input batch placeholder.
    cell, in_state = seq2seq._create_cell(one_hot_seq, no_stacked_cells)
    self.assertIsInstance(cell, tf.contrib.rnn.MultiRNNCell)
    self.assertEqual(len(in_state), no_stacked_cells)
    for state in in_state:
        self.assertHasShape(state, [None, hidden_size])
    # Before __call__ is invoked on the cell, its internal variables are not
    # created, so there is not much more we can test right now.
    self.assertListEqual(tf.trainable_variables(), [])
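
# Note: `assertHasShape` is a custom assertion, not part of unittest. A
# hypothetical sketch of such a helper, which compares a tensor's static shape
# against an expected shape, treating None as a wildcard dimension:
def assertHasShape(self, tensor, expected_shape):
    actual_shape = tensor.get_shape().as_list()
    self.assertEqual(len(actual_shape), len(expected_shape))
    for actual_dim, expected_dim in zip(actual_shape, expected_shape):
        if expected_dim is not None:
            self.assertEqual(actual_dim, expected_dim)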
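
# The `_create_cell` method under test is not shown here. A minimal sketch of
# what it might look like, assuming TF 1.x contrib cells and GRU units (whose
# per-layer state is a single [batch, hidden_size] tensor, matching the shape
# asserted above), with `self.hidden_size` assumed to be set on the model:
def _create_cell(self, seq, no_stacked_cells):
    # Stack independent GRU cells; MultiRNNCell.zero_state returns a tuple
    # with one state tensor per layer.
    cells = [tf.contrib.rnn.GRUCell(self.hidden_size)
             for _ in range(no_stacked_cells)]
    cell = tf.contrib.rnn.MultiRNNCell(cells)
    # Take the batch size from the dynamic first dimension of the input so
    # the returned states keep a [None, hidden_size] static shape.
    batch_size = tf.shape(seq)[0]
    in_state = cell.zero_state(batch_size, tf.float32)
    return cell, in_state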