bridges_test.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def setUp(self):
    super(BridgeTest, self).setUp()
    self.batch_size = 4
    self.encoder_cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.GRUCell(4), tf.contrib.rnn.GRUCell(8)])
    self.decoder_cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.LSTMCell(16), tf.contrib.rnn.GRUCell(8)])
    final_encoder_state = nest.map_structure(
        lambda x: tf.convert_to_tensor(
            value=np.random.randn(self.batch_size, x),
            dtype=tf.float32),
        self.encoder_cell.state_size)
    self.encoder_outputs = EncoderOutput(
        outputs=tf.convert_to_tensor(
            value=np.random.randn(self.batch_size, 10, 16), dtype=tf.float32),
        attention_values=tf.convert_to_tensor(
            value=np.random.randn(self.batch_size, 10, 16), dtype=tf.float32),
        attention_values_length=np.full([self.batch_size], 10),
        final_state=final_encoder_state)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号