base_aligner.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:almond-nnparser 作者: Stanford-Mobisocial-IoT-Lab 项目源码 文件源码
def build(self):
        """Assemble the model graph and print a per-variable size summary.

        The graph is built in this order: placeholders, input op, encoder
        (variable scope ``RNNEnc``), training decoder (variable scope
        ``RNNDec``), training loss and training op, inference decoder
        (``RNNDec`` again, with ``reuse=True``), predictions and evaluation
        loss.

        Attributes set on ``self``: ``final_encoder_state``, ``loss``,
        ``train_op``, ``pred`` and ``eval_loss``.

        Finally, every variable in the ``TRAINABLE_VARIABLES`` collection is
        printed together with its element count, followed by the total.
        """
        self.add_placeholders()

        # A single initializer with a fixed seed is shared by the input op
        # and by both variable scopes below.
        xavier = tf.contrib.layers.xavier_initializer(seed=1234)
        inputs, output_embed_matrix = self.add_input_op(xavier)

        # the encoder
        with tf.variable_scope('RNNEnc', initializer=xavier):
            enc_hidden_states, enc_final_state = self.add_encoder_op(inputs=inputs)
        self.final_encoder_state = enc_final_state

        # the training decoder
        with tf.variable_scope('RNNDec', initializer=xavier):
            train_preds = self.add_decoder_op(enc_final_state=enc_final_state, enc_hidden_states=enc_hidden_states, output_embed_matrix=output_embed_matrix, training=True)
        # The training loss includes the regularization term; eval_loss below
        # does not.
        self.loss = self.add_loss_op(train_preds) + self.add_regularization_loss()
        self.train_op = self.add_training_op(self.loss)

        # the inference decoder; reuse=True re-enters the 'RNNDec' scope so
        # the variables created for the training decoder are reused rather
        # than created a second time
        with tf.variable_scope('RNNDec', initializer=xavier, reuse=True):
            eval_preds = self.add_decoder_op(enc_final_state=enc_final_state, enc_hidden_states=enc_hidden_states, output_embed_matrix=output_embed_matrix, training=False)
        self.pred = self.finalize_predictions(eval_preds)
        self.eval_loss = self.add_loss_op(eval_preds)

        def get_size(w):
            """Return the number of elements in variable ``w``.

            This is the product of all static dimensions, so a variable of
            any rank is counted correctly (a scalar counts as 1) instead of
            only rank-1 and rank-2 variables.
            """
            shape = w.get_shape()
            count = 1
            for axis in range(shape.ndims):
                count *= int(shape[axis])
            return count

        weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        size = 0
        for w in weights:
            sz = get_size(w)
            print('weight', w, sz)
            size += sz
        print('total model size', size)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号