seq2seq_aligner.py source file

python

Project: almond-nnparser  Author: Stanford-Mobisocial-IoT-Lab
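# Assumed imports for this excerpt (TensorFlow 1.x; tf.contrib is required).
import tensorflow as tf
from tensorflow.contrib.seq2seq import AttentionWrapper, LuongAttention
from tensorflow.python.util import nest

# ParentFeedingCellWrapper, InputIgnoringCellWrapper and Seq2SeqDecoder are
# project-local classes of almond-nnparser; their import paths are not shown here.

# Method of the aligner model class (the enclosing class is not part of this excerpt).
def add_decoder_op(self, enc_final_state, enc_hidden_states, output_embed_matrix, training):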
        cell_dec = tf.contrib.rnn.MultiRNNCell([self.make_rnn_cell(i, True) for i in range(self.config.rnn_layers)])

        encoder_hidden_size = int(enc_hidden_states.get_shape()[-1])
        decoder_hidden_size = int(cell_dec.output_size)

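        # if encoder and decoder have different sizes, add a projection layer
        # NOTE: this path is effectively disabled: the assert below always fails,
        # so the projection code after it never runs
        if encoder_hidden_size != decoder_hidden_size:
            assert False, (encoder_hidden_size, decoder_hidden_size)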
            with tf.variable_scope('hidden_projection'):
                kernel = tf.get_variable('kernel', (encoder_hidden_size, decoder_hidden_size), dtype=tf.float32)

                # apply a relu to the projection for good measure
                enc_final_state = nest.map_structure(lambda x: tf.nn.relu(tf.matmul(x, kernel)), enc_final_state)
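                # contract the last axis of enc_hidden_states (encoder size) with the
                # first axis of kernel, which has shape (encoder size, decoder size)
                enc_hidden_states = tf.nn.relu(tf.tensordot(enc_hidden_states, kernel, [[2], [0]]))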
        else:
            # flatten and repack the state
            enc_final_state = nest.pack_sequence_as(cell_dec.state_size, nest.flatten(enc_final_state))

        if self.config.connect_output_decoder:
            cell_dec = ParentFeedingCellWrapper(cell_dec, enc_final_state)
        else:
            cell_dec = InputIgnoringCellWrapper(cell_dec, enc_final_state)
        if self.config.apply_attention:
            attention = LuongAttention(self.config.decoder_hidden_size, enc_hidden_states, self.input_length_placeholder,
                                       probability_fn=tf.nn.softmax)
            cell_dec = AttentionWrapper(cell_dec, attention,
                                        cell_input_fn=lambda inputs, _: inputs,
                                        attention_layer_size=self.config.decoder_hidden_size,
                                        initial_cell_state=enc_final_state)
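            # the wrapper's zero state already carries the encoder final state,
            # because it was passed above as initial_cell_state
            enc_final_state = cell_dec.zero_state(self.batch_size, dtype=tf.float32)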
        decoder = Seq2SeqDecoder(self.config, self.input_placeholder, self.input_length_placeholder,
                                 self.output_placeholder, self.output_length_placeholder, self.batch_number_placeholder)
        return decoder.decode(cell_dec, enc_final_state, self.config.grammar.output_size, output_embed_matrix, training)
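
The projection branch above is meant to map tensors whose last dimension is encoder_hidden_size onto decoder_hidden_size and then apply a ReLU. The following is a minimal pure-Python sketch of that shape transformation, not the project's implementation: nested lists stand in for tensors, and the helper names relu and project_relu, along with the toy values, are hypothetical and chosen only for illustration.

```python
def relu(x):
    """Rectified linear unit on a single float."""
    return x if x > 0.0 else 0.0


def project_relu(states, kernel):
    """Project [batch][time][enc_size] states with an [enc_size][dec_size] kernel.

    Each encoder vector is multiplied by the kernel (a sum over the encoder
    dimension), then passed through a ReLU. Hypothetical helper, for
    illustration only.
    """
    enc_size = len(kernel)
    dec_size = len(kernel[0])
    projected = []
    for sequence in states:
        projected_sequence = []
        for vector in sequence:
            # the last dimension of the states must match the kernel's first one
            assert len(vector) == enc_size
            projected_sequence.append([
                relu(sum(vector[i] * kernel[i][j] for i in range(enc_size)))
                for j in range(dec_size)
            ])
        projected.append(projected_sequence)
    return projected


# Toy example: batch of 1, 2 time steps, encoder size 3, decoder size 2.
states = [[[1.0, 2.0, 3.0], [-1.0, 2.0, 1.0]]]
kernel = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
projected = project_relu(states, kernel)
print(projected)  # [[[4.0, 0.0], [0.0, 1.0]]]
```

The batch and time dimensions are left untouched; only the last dimension changes from 3 to 2, which is what lets a decoder of a different size attend over the encoder states.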