def model(self, input_text_begin, input_text_end, gene, variation, expected_labels, batch_size,
          vocabulary_size=VOCABULARY_SIZE, embeddings_size=EMBEDDINGS_SIZE, output_classes=9):
    """Build the training graph for the text classification model.

    Loads the word embeddings and runs ``self.text_classification_model`` on
    the inputs inside its slim arg scope. It then wires up the targets, a
    NaN/Inf-checked loss, the optimizer, TensorBoard summaries, metrics and a
    checkpoint saver.

    Args:
        input_text_begin: Input for the beginning of the text, forwarded to
            ``text_classification_model.model`` (shape/dtype defined there --
            TODO confirm).
        input_text_end: Input for the end of the text, forwarded likewise.
        gene: Gene input, forwarded to the model.
        variation: Variation input, forwarded to the model.
        expected_labels: Ground-truth labels, converted to targets via
            ``text_classification_model.targets``.
        batch_size: Batch size forwarded to the model.
        vocabulary_size: Vocabulary size passed to ``_load_embeddings``.
        embeddings_size: Embedding size passed to ``_load_embeddings``.
        output_classes: Number of output classes (default 9).

    Side effects:
        Sets ``self.global_step``, ``self.loss``, ``self.optimizer``,
        ``self.learning_rate``, ``self.metrics`` and ``self.saver``, and
        registers the ``loss`` (and, when available, ``learning_rate``)
        scalar summaries.

    Returns:
        None.
    """
    embeddings = _load_embeddings(vocabulary_size, embeddings_size)
    self.global_step = training_util.get_or_create_global_step()

    with slim.arg_scope(self.text_classification_model.model_arg_scope()):
        outputs = self.text_classification_model.model(input_text_begin, input_text_end,
                                                        gene, variation, output_classes,
                                                        embeddings=embeddings,
                                                        batch_size=batch_size)

    targets = self.text_classification_model.targets(expected_labels, output_classes)
    loss = self.text_classification_model.loss(targets, outputs)
    # Insert the NaN/Inf check *before* the summary and optimizer are built, so
    # that everything derived from the loss depends on the CheckNumerics op.
    # Presumably `optimize` differentiates the loss it is given; if so, the
    # training op now fails on a non-finite loss instead of applying the update.
    self.loss = tf.check_numerics(loss, 'loss is nan')
    tf.summary.scalar('loss', self.loss)

    # `optimize` returns an (optimizer, learning_rate) pair; the learning rate
    # may be None, in which case no learning-rate summary is recorded.
    self.optimizer, self.learning_rate = \
        self.text_classification_model.optimize(self.loss, self.global_step)
    if self.learning_rate is not None:
        tf.summary.scalar('learning_rate', self.learning_rate)

    self.metrics = metrics.single_label(outputs['prediction'], targets)
    self.saver = tf.train.Saver()
    return None
text_classification_train.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录