doc2vec_train_doc_prediction.py source code

python

Project: kaggle_redefining_cancer_treatment  Author: jorgemf
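import tensorflow as tf  # TensorFlow 1.x graph-mode API
from tensorflow.python.training import training_util

# Excerpt: a method of the project's training class. EMBEDDINGS_SIZE, the
# D2V_DOC_LEARNING_RATE_* constants, doc2vec_prediction_model and the metrics
# module are defined elsewhere in the project and are not shown here.

def model(self,
          input_vectors, input_gene, input_variation, output_label, batch_size,
          embedding_size=EMBEDDINGS_SIZE,
          output_classes=9,
          learning_rate_initial=D2V_DOC_LEARNING_RATE_INITIAL,
          learning_rate_decay=D2V_DOC_LEARNING_RATE_DECAY,
          learning_rate_decay_steps=D2V_DOC_LEARNING_RATE_DECAY_STEPS):
    self.global_step = training_util.get_or_create_global_step()

    # classification head on top of the doc2vec document vectors
    logits, targets = doc2vec_prediction_model(input_vectors, input_gene, input_variation,
                                               output_label, batch_size,
                                               is_training=True, embedding_size=embedding_size,
                                               output_classes=output_classes)

    # class probabilities over the output_classes labels
    self.prediction = tf.nn.softmax(logits)

    # mean softmax cross-entropy over the batch
    self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    self.loss = tf.reduce_mean(self.loss)
    # fail fast if the loss becomes NaN or Inf; applied before minimize() so the
    # training op itself depends on the check
    self.loss = tf.check_numerics(self.loss, 'loss is nan')
    tf.summary.scalar('loss', self.loss)

    # learning rate & optimizer: with staircase=True the rate is multiplied by
    # learning_rate_decay once every learning_rate_decay_steps steps
    self.learning_rate = tf.train.exponential_decay(learning_rate_initial, self.global_step,
                                                    learning_rate_decay_steps,
                                                    learning_rate_decay,
                                                    staircase=True, name='learning_rate')
    tf.summary.scalar('learning_rate', self.learning_rate)
    sgd = tf.train.GradientDescentOptimizer(self.learning_rate)
    # minimize() also increments global_step after each update
    self.optimizer = sgd.minimize(self.loss, global_step=self.global_step)

    # metrics computed by the project's metrics module
    self.metrics = metrics.single_label(self.prediction, targets)

    # saver to write model checkpoints
    self.saver = tf.train.Saver()

    return None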
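The learning-rate schedule above comes from tf.train.exponential_decay with staircase=True. The sketch below is a minimal pure-Python version of that formula, intended only to show how the rate evolves with global_step. It is not TensorFlow's implementation. The values in the usage lines (0.1, 100, 0.5) are illustrative assumptions, because the project's D2V_DOC_LEARNING_RATE_* constants are not shown here.

```python
import math


def staircase_exponential_decay(initial_lr: float, global_step: int,
                                decay_steps: int, decay_rate: float) -> float:
    """Learning rate of exponential decay with staircase=True.

    Computes initial_lr * decay_rate ** floor(global_step / decay_steps),
    so the rate is flat for decay_steps steps and then drops by decay_rate.
    """
    periods = math.floor(global_step / decay_steps)  # completed decay periods
    return initial_lr * decay_rate ** periods


if __name__ == "__main__":
    # Illustrative values only (assumed, not the project's constants).
    for step in (0, 99, 100, 250):
        print(step, staircase_exponential_decay(0.1, step, 100, 0.5))
```

Without staircase, the exponent would be the real-valued global_step / decay_steps, which gives a smooth decay instead of discrete drops.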
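The prediction and loss lines apply a softmax to the logits and average the per-example softmax cross-entropy over the batch. The following minimal pure-Python sketch computes the same quantities for small lists. It assumes each label row is a one-hot vector or a probability distribution, and it uses the log-sum-exp trick for numerical stability. It is an illustration of the math, not TensorFlow's kernel.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Class probabilities, shifted by the max logit to avoid overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_softmax_cross_entropy(labels: Sequence[Sequence[float]],
                               logits: Sequence[Sequence[float]]) -> float:
    """Batch mean of -sum_k y_k * log(softmax(z)_k)."""
    losses = []
    for y, z in zip(labels, logits):
        m = max(z)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        # -log softmax(z)_k == log_sum_exp - z_k
        losses.append(sum(yk * (log_sum_exp - zk) for yk, zk in zip(y, z)))
    return sum(losses) / len(losses)
```

With uniform logits over three classes, every example contributes log(3), so the batch mean is also log(3).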