transducer_model.py source file

python

Project: segmenter  Author: yanshao9798
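    def define(self, char_num, rnn_dim, emb_dim, max_x, max_y, write_trans_model=True):
        """Build the attention seq2seq transducer graph (graph-mode TensorFlow) and optionally save its hyperparameters."""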
        self.decode_step = max_y
        self.encode_step = max_x
        self.en_vec = [tf.placeholder(tf.int32, [None], name='en_input' + str(i)) for i in range(max_x)]
        self.trans_labels = [tf.placeholder(tf.int32, [None], name='de_input' + str(i)) for i in range(max_y)]
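        # Loss weights: 1.0 where the label id is positive, 0.0 where it is 0, so id 0 (presumably padding) is masked out.
        weights = [tf.cast(tf.sign(ot_t), tf.float32) for ot_t in self.trans_labels]
        # Decoder inputs: the labels shifted right by one step, with an all-zero step in front as the start symbol.
        self.de_vec = [tf.zeros_like(self.trans_labels[0], tf.int32)] + self.trans_labels[:-1]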
        self.feed_previous = tf.placeholder(tf.bool)
        self.trans_l_rate = tf.placeholder(tf.float32, [], name='learning_rate')
        seq_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, state_is_tuple=True)
        self.trans_output, states = seq2seq.embedding_attention_seq2seq(self.en_vec, self.de_vec, seq_cell, char_num,
                                                                        char_num, emb_dim, feed_previous=self.feed_previous)

        loss = seq2seq.sequence_loss(self.trans_output, self.trans_labels, weights)
        optimizer = tf.train.AdagradOptimizer(learning_rate=self.trans_l_rate)

        # Clip the gradients to a global norm of 5.0 before applying them.
        params = tf.trainable_variables()
        gradients = tf.gradients(loss, params)
        clipped_gradients, norm = tf.clip_by_global_norm(gradients, 5.0)
        self.trans_train = optimizer.apply_gradients(zip(clipped_gradients, params))

        self.saver = tf.train.Saver()

        if write_trans_model:
            # Save the hyperparameters alongside the trained model.
            param_dic = {
                'char_num': char_num,
                'rnn_dim': rnn_dim,
                'emb_dim': emb_dim,
                'max_x': max_x,
                'max_y': max_y,
            }
            # Binary mode: pickle writes bytes, so text mode 'w' fails on Python 3.
            with open(self.trained + '_model', 'wb') as f_model:
                pickle.dump(param_dic, f_model)
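
The way `define` derives the decoder inputs and loss weights from the labels is easy to miss, so here is a minimal plain-Python sketch of that step. It is an illustration only, not part of the project: `shift_and_mask` is a hypothetical helper, and it assumes that label ids are non-negative and that id 0 is used for padding.

```python
def shift_and_mask(label_batch):
    """Derive time-major labels, decoder inputs and loss weights.

    label_batch: list of equal-length label sequences (one per example),
    assumed to hold non-negative ids with 0 used for padding.
    """
    batch_size = len(label_batch)
    # Time-major view: one list of batch entries per decoding step,
    # mirroring the list of per-step placeholders in `define`.
    steps = [list(step) for step in zip(*label_batch)]
    # Decoder inputs: an all-zero first step, then the labels shifted right by one.
    decoder_inputs = [[0] * batch_size] + steps[:-1]
    # Loss weights: 1.0 for positive ids, 0.0 for id 0 (same effect as tf.sign here).
    weights = [[1.0 if label > 0 else 0.0 for label in step] for step in steps]
    return steps, decoder_inputs, weights


labels = [[5, 7, 2, 0],
          [4, 2, 0, 0]]
steps, decoder_inputs, weights = shift_and_mask(labels)
```

Each decoder step sees the previous step's label as input, and positions holding id 0 contribute nothing to the sequence loss.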