def define(self, char_num, rnn_dim, emb_dim, max_x, max_y, write_trans_model=True):
    """Build the TF graph for an attention-based seq2seq transliteration model.

    Args:
        char_num: size of the character vocabulary (shared by encoder/decoder).
        rnn_dim: hidden-state size of the LSTM cell.
        emb_dim: character-embedding size.
        max_x: maximum source-sequence length (number of encoder steps).
        max_y: maximum target-sequence length (number of decoder steps).
        write_trans_model: if True, pickle the hyper-parameters to
            ``<self.trained>_model`` so the graph can be rebuilt at load time.
    """
    self.decode_step = max_y
    self.encode_step = max_x
    # Legacy TF1 seq2seq takes a *list* of [batch] int32 tensors, one
    # placeholder per time step, rather than a single [batch, time] tensor.
    self.en_vec = [tf.placeholder(tf.int32, [None], name='en_input' + str(i))
                   for i in range(max_x)]
    self.trans_labels = [tf.placeholder(tf.int32, [None], name='de_input' + str(i))
                         for i in range(max_y)]
    # Per-step loss weights: 1.0 for real tokens, 0.0 for padding
    # (assumes token id 0 is PAD — TODO confirm against the vocabulary).
    weights = [tf.cast(tf.sign(ot_t), tf.float32) for ot_t in self.trans_labels]
    # Decoder inputs = labels shifted right one step with a zero (GO) token
    # prepended; the final label is dropped to keep the lengths equal.
    self.de_vec = [tf.zeros_like(self.trans_labels[0], tf.int32)] + self.trans_labels[:-1]
    self.feed_previous = tf.placeholder(tf.bool)
    self.trans_l_rate = tf.placeholder(tf.float32, [], name='learning_rate')
    seq_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, state_is_tuple=True)
    # _states (final encoder/decoder states) is intentionally unused.
    self.trans_output, _states = seq2seq.embedding_attention_seq2seq(
        self.en_vec, self.de_vec, seq_cell, char_num,
        char_num, emb_dim, feed_previous=self.feed_previous)
    loss = seq2seq.sequence_loss(self.trans_output, self.trans_labels, weights)
    optimizer = tf.train.AdagradOptimizer(learning_rate=self.trans_l_rate)
    params = tf.trainable_variables()
    gradients = tf.gradients(loss, params)
    # Clip by global norm (cap 5.0) to stabilise RNN training;
    # the returned global norm itself is not needed.
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    self.trans_train = optimizer.apply_gradients(zip(clipped_gradients, params))
    self.saver = tf.train.Saver()
    if write_trans_model:
        param_dic = {
            'char_num': char_num,
            'rnn_dim': rnn_dim,
            'emb_dim': emb_dim,
            'max_x': max_x,
            'max_y': max_y,
        }
        # BUG FIX: pickle requires a binary file handle — the original opened
        # in text mode ('w'), which raises TypeError under Python 3. A context
        # manager also guarantees the file is closed if dump() raises.
        with open(self.trained + '_model', 'wb') as f_model:
            pickle.dump(param_dic, f_model)
# (removed non-code web-page residue accidentally pasted here: "评论列表" / "文章目录",
#  i.e. the blog's "comment list" / "table of contents" labels — invalid Python)