text_classification_train.py source code

python

Project: kaggle_redefining_cancer_treatment    Author: jorgemf
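```python
import tensorflow as tf
import tensorflow.contrib.slim as slim  # TensorFlow 1.x only
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import training_util

# VOCABULARY_SIZE, EMBEDDINGS_SIZE and _load_embeddings are defined elsewhere in
# the project; this is a method of a class that provides self.text_classification_model.


def model(self, input_text_begin, input_text_end, gene, variation, batch_size,
          vocabulary_size=VOCABULARY_SIZE, embeddings_size=EMBEDDINGS_SIZE, output_classes=9):
    # load the word embeddings
    embeddings = _load_embeddings(vocabulary_size, embeddings_size)
    # global step: an op that increments it by 1 each time it runs
    self.global_step = training_util.get_or_create_global_step()
    self.global_step = tf.assign_add(self.global_step, 1)
    # model: ops created in this block run only after the global step increment
    with tf.control_dependencies([self.global_step]):
        with slim.arg_scope(self.text_classification_model.model_arg_scope()):
            self.outputs = self.text_classification_model.model(
                input_text_begin, input_text_end, gene, variation, output_classes,
                embeddings=embeddings, batch_size=batch_size, training=False)
    # restore only the trainable variables (the global step is not trainable)
    self.saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
    return self.outputs
```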
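Passing `var_list=tf_variables.trainable_variables()` to `tf.train.Saver` means only the trainable weights are restored from a checkpoint, so non-trainable variables such as the global step (created with `trainable=False` by `get_or_create_global_step`) keep their fresh values in this graph. The framework-free sketch that follows illustrates that selective restore together with the increment-then-run ordering imposed by `tf.control_dependencies`. It is an illustration under assumptions, not TensorFlow's API: `Variable`, `trainable_variables`, `restore` and `eval_step` are hypothetical stand-ins, and the checkpoint is modeled as a plain dict.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Variable:
    """Hypothetical stand-in for a graph variable (not tf.Variable)."""
    name: str
    value: float
    trainable: bool


def trainable_variables(variables: List[Variable]) -> List[Variable]:
    """Keep only trainable variables, like the filter behind trainable_variables()."""
    return [v for v in variables if v.trainable]


def restore(var_list: List[Variable], checkpoint: Dict[str, float]) -> List[str]:
    """Copy checkpoint values into var_list only; other variables are untouched."""
    restored = []
    for v in var_list:
        v.value = checkpoint[v.name]  # a name missing from the checkpoint raises KeyError
        restored.append(v.name)
    return restored


def eval_step(global_step: Variable, weights: Variable, x: float) -> float:
    """Increment the step counter first, then run a toy one-weight model,
    mirroring the control-dependency ordering in the snippet."""
    global_step.value += 1
    return weights.value * x


if __name__ == "__main__":
    step = Variable("global_step", 0.0, trainable=False)
    kernel = Variable("dense/kernel", 0.0, trainable=True)
    emb = Variable("embeddings", 0.0, trainable=True)
    graph = [emb, kernel, step]

    ckpt = {"embeddings": 1.5, "dense/kernel": 2.0, "global_step": 1000.0}
    print(restore(trainable_variables(graph), ckpt))  # ['embeddings', 'dense/kernel']
    print(step.value)  # 0.0: the step counter was not restored
    print(eval_step(step, kernel, 3.0), step.value)  # 6.0 1.0
```

Restricting the restore this way also means the checkpoint does not need to contain any non-trainable state for the evaluation graph to load.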