learners.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:auDeep 作者: auDeep 项目源码 文件源码
def fit(self,
            features: np.ndarray,
            labels: np.ndarray,
            quiet=False):
        """Train the MLP on the given features and labels and save a checkpoint.

        The complete data set is placed in the graph, and each epoch runs a
        single optimization step over all instances (full-batch training).
        After the last epoch, the variables created under the "mlp" variable
        scope are written to a new timestamped checkpoint inside
        ``self._checkpoint_dir``. The checkpoint path is stored in
        ``self._latest_checkpoint``.

        Parameters
        ----------
        features: numpy.ndarray
            The training features; the first axis indexes instances.
        labels: numpy.ndarray
            Integer labels (cast to int32), one per instance.
        quiet: bool, default False
            Not used by this method; kept for interface compatibility.
        """
        # generic parameter checks
        super().fit(features, labels)

        # NOTE(review): this assumes labels are the contiguous range 0..k-1.
        # If label values have gaps, this count is smaller than
        # max(label) + 1 — confirm against callers.
        self._num_labels = len(np.unique(labels))

        graph = tf.Graph()

        with graph.as_default():
            # NOTE(review): the whole data set is embedded in the graph as the
            # variables' initial values; very large inputs may hit TF1's 2 GB
            # GraphDef limit.
            tf_inputs = tf.Variable(initial_value=features, trainable=False, dtype=tf.float32)
            tf_labels = tf.Variable(initial_value=labels, trainable=False, dtype=tf.int32)

            if self._shuffle_training:
                # Draw ONE permutation and apply it to both tensors, so features
                # and labels are guaranteed to stay aligned. (Two separate shuffle
                # ops would only line up if they happened to draw the same order.)
                # Ranging over a Python int keeps the static batch size known.
                tf_permutation = tf.random_shuffle(tf.range(len(features)), seed=42)
                tf_inputs = tf.gather(tf_inputs, tf_permutation)
                tf_labels = tf.gather(tf_labels, tf_permutation)

            with tf.variable_scope("mlp"):
                tf_logits = self._model.inference(tf_inputs, self._keep_prob, self._num_labels)
                tf_loss = self._model.loss(tf_logits, tf_labels)
                tf_train_op = self._model.optimize(tf_loss, self._learning_rate)

            tf_init_op = tf.global_variables_initializer()
            # Only checkpoint the variables created under the "mlp" scope; the
            # data variables above were created outside it and are not saved.
            tf_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="mlp"))

        # The context manager closes the session even if training or saving raises.
        with tf.Session(graph=graph) as session:
            session.run(tf_init_op)

            # One full-batch optimization step per epoch.
            for _ in range(self._num_epochs):
                session.run(tf_train_op)

            # timestamped model file
            # (assumes self._checkpoint_dir is a pathlib.Path-like object — confirm)
            self._latest_checkpoint = self._checkpoint_dir / "model_{:%Y%m%d%H%M%S%f}".format(datetime.datetime.now())
            tf_saver.save(session, str(self._latest_checkpoint), write_meta_graph=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号