def model(self, input_text_begin, input_text_end, gene, variation, expected_labels, batch_size,
          vocabulary_size=VOCABULARY_SIZE, embeddings_size=EMBEDDINGS_SIZE, output_classes=9):
    """Build the evaluation graph for the wrapped text classification model.

    The network from ``self.text_classification_model`` is built with
    ``training=False`` inside that model's own slim arg scope. A loss averaged
    over every evaluation step run so far is exposed on ``self``.

    Args:
        input_text_begin: Forwarded unchanged to the wrapped model's ``model``
            (shape/dtype are whatever that model expects — not visible here).
        input_text_end: Forwarded unchanged to the wrapped model's ``model``.
        gene: Forwarded unchanged to the wrapped model's ``model``.
        variation: Forwarded unchanged to the wrapped model's ``model``.
        expected_labels: Converted to targets by the wrapped model's ``targets``.
        batch_size: Forwarded to the wrapped model as ``batch_size``.
        vocabulary_size: Forwarded to ``_load_embeddings``.
        embeddings_size: Forwarded to ``_load_embeddings``.
        output_classes: Number of output classes passed to ``model`` and
            ``targets`` (default 9).

    Side effects:
        Sets ``self.accumulated_loss`` (op adding the batch loss to a running
        sum), ``self.loss`` (running sum divided by the number of evaluated
        steps) and ``self.metrics`` (result of ``metrics.single_label``).
        Registers a ``'loss'`` scalar summary.

    Returns:
        None.
    """
    embedding_table = _load_embeddings(vocabulary_size, embeddings_size)
    classifier = self.text_classification_model

    # Build the network in inference mode under the model's own slim defaults.
    with slim.arg_scope(classifier.model_arg_scope()):
        outputs = classifier.model(
            input_text_begin, input_text_end, gene, variation, output_classes,
            embeddings=embedding_table, batch_size=batch_size, training=False)

    targets = classifier.targets(expected_labels, output_classes)
    batch_loss = classifier.loss(targets, outputs)

    # Streaming mean of the loss over evaluation steps. ``self.loss`` depends
    # on both assign_add ops, so fetching it adds this batch's loss to the
    # running sum AND bumps the step counter before dividing sum by count.
    # The counter is incremented before the division, so the first step
    # divides by 1, not 0.
    # NOTE(review): fetching ``self.accumulated_loss`` on its own would grow
    # the sum without the counter. Neither variable is reset here; presumably
    # each evaluation run uses a fresh graph/session — confirm with the caller.
    loss_sum = tf.Variable(0.0, dtype=tf.float32, name='accumulated_loss',
                           trainable=False)
    self.accumulated_loss = tf.assign_add(loss_sum, batch_loss)
    step_counter = tf.Variable(0, dtype=tf.int32, name='eval_step', trainable=False)
    steps_so_far = tf.assign_add(step_counter, 1)
    self.loss = self.accumulated_loss / tf.cast(steps_so_far, dtype=tf.float32)
    tf.summary.scalar('loss', self.loss)

    # Non-moving-average metrics computed on the model's 'prediction' output.
    self.metrics = metrics.single_label(outputs['prediction'], targets, moving_average=False)
text_classification_train.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录