def build_model(self):
    """Construct the training graph.

    Instantiates the model selected by ``FLAGS.model_type``, creates an
    RMSProp optimizer whose gradients are clipped by global norm, and a
    saver over all global variables.

    Side effects: sets ``self.model``, ``self.global_step``,
    ``self.optimizer``, ``self.updates`` and ``self.saver``.
    """
    model_cls = classmap[FLAGS.model_type]
    self.model = model_cls(
        hidden_size=FLAGS.hidden,
        vocab_size=self.vocab_size,
        encoder_in_size=self.data.feats.shape[-1],
        encoder_in_length=self.data.feats.shape[1],
        # decoder input loses one position (presumably the shifted <BOS>
        # token — TODO confirm against the data pipeline)
        decoder_in_length=self.data.decoder_in.shape[-1] - 1,
        word2vec_weight=self.w2v_W,
        embedding_size=FLAGS.embedding_dim,
        neg_sample_num=self.sample_num,
        start_id=self.vocab_processor._mapping['<BOS>'],
        end_id=self.vocab_processor._mapping['<EOS>'],
        Bk=FLAGS.K)

    # Non-trainable step counter, incremented by apply_gradients below.
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.optimizer = tf.train.RMSPropOptimizer(FLAGS.lr)

    # Clip gradients by global norm (max 5) before applying them.
    trainable = tf.trainable_variables()
    raw_grads = tf.gradients(self.model.cost, trainable)
    clipped_grads, _ = tf.clip_by_global_norm(raw_grads, 5)
    self.updates = self.optimizer.apply_gradients(
        zip(clipped_grads, trainable), global_step=self.global_step)

    self.saver = tf.train.Saver(tf.global_variables())