def model(self,
          input_vectors, input_gene, input_variation, output_label, batch_size,
          embedding_size=EMBEDDINGS_SIZE,
          output_classes=9,
          learning_rate_initial=D2V_DOC_LEARNING_RATE_INITIAL,
          learning_rate_decay=D2V_DOC_LEARNING_RATE_DECAY,
          learning_rate_decay_steps=D2V_DOC_LEARNING_RATE_DECAY_STEPS):
    """Build the training graph for the doc2vec document-prediction model.

    Wires the network returned by ``doc2vec_prediction_model`` to the following:
    a mean softmax cross-entropy loss (guarded by a NaN/Inf check), a plain SGD
    optimizer with a staircase exponentially-decayed learning rate, the metrics
    from ``metrics.single_label``, and a ``tf.train.Saver``.

    Args:
        input_vectors, input_gene, input_variation, output_label: forwarded
            unchanged to ``doc2vec_prediction_model`` (presumably input tensors
            from the data pipeline -- confirm against the caller).
        batch_size: forwarded to ``doc2vec_prediction_model``.
        embedding_size: embedding size passed to the prediction model.
        output_classes: number of output classes passed to the prediction model.
        learning_rate_initial: starting learning rate of the decay schedule.
        learning_rate_decay: decay rate for ``tf.train.exponential_decay``; with
            ``staircase=True`` it is applied once every
            ``learning_rate_decay_steps`` global steps.
        learning_rate_decay_steps: number of global steps between decays.

    Sets on ``self``:
        ``global_step``, ``prediction`` (softmax of the logits), ``loss``
        (checked scalar mean loss), ``learning_rate``, ``optimizer`` (the SGD
        train op, which also increments ``global_step``), ``metrics``, ``saver``.

    Returns:
        None.
    """
    self.global_step = training_util.get_or_create_global_step()
    logits, targets = doc2vec_prediction_model(input_vectors, input_gene, input_variation,
                                               output_label, batch_size,
                                               is_training=True, embedding_size=embedding_size,
                                               output_classes=output_classes)
    self.prediction = tf.nn.softmax(logits)

    # NOTE(review): tf.nn.softmax_cross_entropy_with_logits is deprecated in later
    # TF 1.x releases in favour of the _v2 variant (which also backprops into
    # labels); kept as-is to preserve behaviour.
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    loss = tf.reduce_mean(loss)
    # Apply the NaN/Inf check *before* the summary and optimizer are built on the
    # loss. That way the train op depends on it and fails fast on a non-finite loss.
    # A check added after minimize() only runs when self.loss itself is fetched.
    self.loss = tf.check_numerics(loss, 'loss is nan')
    tf.summary.scalar('loss', self.loss)

    # learning rate & optimizer
    self.learning_rate = tf.train.exponential_decay(learning_rate_initial, self.global_step,
                                                    learning_rate_decay_steps,
                                                    learning_rate_decay,
                                                    staircase=True, name='learning_rate')
    tf.summary.scalar('learning_rate', self.learning_rate)
    sgd = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.optimizer = sgd.minimize(self.loss, global_step=self.global_step)

    # metrics
    self.metrics = metrics.single_label(self.prediction, targets)
    # saver to save the model (created last so it captures every variable above)
    self.saver = tf.train.Saver()
    return None
# Source file: doc2vec_train_doc_prediction.py (python)
# (Scraped page metadata, translated: views 24, favorites 0, likes 0, comments 0; comment list; table of contents)