def setUp(self):
    """Create the RNN cells and the encoder output fixture for each test.

    Assigns ``batch_size``, ``encoder_cell``, ``decoder_cell`` and
    ``encoder_outputs`` on the test instance.
    """
    super(BridgeTest, self).setUp()
    self.batch_size = 4

    rnn = tf.contrib.rnn
    self.encoder_cell = rnn.MultiRNNCell([rnn.GRUCell(4), rnn.GRUCell(8)])
    self.decoder_cell = rnn.MultiRNNCell([rnn.LSTMCell(16), rnn.GRUCell(8)])

    def _random_tensor(*shape):
        """Return a float32 tensor filled with `np.random.randn` samples."""
        return tf.convert_to_tensor(
            value=np.random.randn(*shape), dtype=tf.float32)

    # One random tensor per entry of the encoder cell's `state_size`,
    # each with `batch_size` rows.
    final_encoder_state = nest.map_structure(
        lambda size: _random_tensor(self.batch_size, size),
        self.encoder_cell.state_size)

    # `outputs` and `attention_values` are drawn independently but share
    # the same (batch_size, 10, 16) shape.
    sequence_shape = (self.batch_size, 10, 16)
    outputs = _random_tensor(*sequence_shape)
    attention_values = _random_tensor(*sequence_shape)

    # Every length entry is 10, matching the second dimension above.
    attention_values_length = np.full([self.batch_size], 10)

    self.encoder_outputs = EncoderOutput(
        outputs=outputs,
        attention_values=attention_values,
        attention_values_length=attention_values_length,
        final_state=final_encoder_state)
评论列表
文章目录