text_classification_train.py source code

python

Project: kaggle_redefining_cancer_treatment  Author: jorgemf
def model(self, input_text_begin, input_text_end, gene, variation, expected_labels, batch_size,
          vocabulary_size=VOCABULARY_SIZE, embeddings_size=EMBEDDINGS_SIZE, output_classes=9):
    # load the word embeddings used by the classifier
    embeddings = _load_embeddings(vocabulary_size, embeddings_size)
    # build the model graph in inference mode (training=False)
    with slim.arg_scope(self.text_classification_model.model_arg_scope()):
        outputs = self.text_classification_model.model(input_text_begin, input_text_end,
                                                       gene, variation, output_classes,
                                                       embeddings=embeddings,
                                                       batch_size=batch_size,
                                                       training=False)
    # loss of the current batch
    targets = self.text_classification_model.targets(expected_labels, output_classes)
    loss = self.text_classification_model.loss(targets, outputs)
    # running mean of the loss: every evaluation of self.loss adds the batch loss to the
    # accumulator, increments the step counter and divides the sum by the step count
    self.accumulated_loss = tf.Variable(0.0, dtype=tf.float32, name='accumulated_loss',
                                        trainable=False)
    self.accumulated_loss = tf.assign_add(self.accumulated_loss, loss)
    step = tf.Variable(0, dtype=tf.int32, name='eval_step', trainable=False)
    step_increase = tf.assign_add(step, 1)
    self.loss = self.accumulated_loss / tf.cast(step_increase, dtype=tf.float32)
    tf.summary.scalar('loss', self.loss)
    # metrics computed from the predictions, without a moving average
    self.metrics = metrics.single_label(outputs['prediction'], targets, moving_average=False)
    return None
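
The loss bookkeeping in this method amounts to a running mean over evaluation batches: each run adds the batch loss to an accumulator, increments a step counter, and reports the ratio of the two. The following is a minimal pure-Python sketch of that arithmetic only. The class `RunningMeanLoss` and its `update` method are hypothetical names introduced for illustration and are not part of the project or of TensorFlow.

```python
class RunningMeanLoss:
    """Running mean of per-batch losses (hypothetical helper, not project code)."""

    def __init__(self):
        self.accumulated_loss = 0.0  # sum of all batch losses seen so far
        self.step = 0                # number of batches seen so far

    def update(self, batch_loss):
        # plays the role of the two tf.assign_add ops: add the loss, bump the step
        self.accumulated_loss += float(batch_loss)
        self.step += 1
        # plays the role of self.loss: accumulated sum divided by the step count
        return self.accumulated_loss / self.step


tracker = RunningMeanLoss()
means = [tracker.update(loss) for loss in (2.0, 1.0, 0.0)]
print(means)  # [2.0, 1.5, 1.0]
```

Because the step counter is incremented before the division, the first reported value is the first batch loss itself and no division by zero can occur.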