base_aligner.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:almond-nnparser 作者: Stanford-Mobisocial-IoT-Lab 项目源码 文件源码
def add_input_op(self, xavier):
        with tf.variable_scope('embed'):
            # first the embed the input
            if self.config.train_input_embeddings:
                if self.config.input_embedding_matrix:
                    initializer = tf.constant_initializer(self.config.input_embedding_matrix)
                else:
                    initializer = xavier
                input_embed_matrix = tf.get_variable('input_embedding',
                                                     shape=(self.config.dictionary_size, self.config.embed_size),
                                                     initializer=initializer)
            else:
                input_embed_matrix = tf.constant(self.config.input_embedding_matrix)

            # dictionary size x embed_size
            assert input_embed_matrix.get_shape() == (self.config.dictionary_size, self.config.embed_size)

            # now embed the output
            if self.config.train_output_embeddings:
                output_embed_matrix = tf.get_variable('output_embedding',
                                                      shape=(self.config.output_size, self.config.output_embed_size),
                                                      initializer=xavier)
            else:
                output_embed_matrix = tf.constant(self.config.output_embedding_matrix)

            assert output_embed_matrix.get_shape() == (self.config.output_size, self.config.output_embed_size)

        inputs = tf.nn.embedding_lookup([input_embed_matrix], self.input_placeholder)
        # batch size x max length x embed_size
        assert inputs.get_shape()[1:] == (self.config.max_length, self.config.embed_size)
        return inputs, output_embed_matrix
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号