models.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tensorflow-fcwta 作者: guoguo12 项目源码 文件源码
def _initialize_vars(self):
        """Sets up the training graph.

        Builds, in order:

        * ``self.global_step`` and the ``self.input`` placeholder, inside the
          ``self.name`` variable scope;
        * the encoder stack, ending in ``self.encoded``;
        * a sparsity mask that, for every hidden unit (column of
          ``self.encoded``), keeps only its ``k = int(self.sparsity *
          self.batch_size)`` largest activations across the batch and zeroes
          the rest;
        * ``self.decoded``, the reconstruction ``self.loss``,
          ``self.optimizer_op`` and ``self.saver``.

        Raises:
            ValueError: if ``k`` is not in ``[1, self.batch_size]``; the
                sparsity mask cannot be built otherwise. Checked before any
                graph ops are created.
        """
        # Number of activations each hidden unit keeps per batch.
        k = int(self.sparsity * self.batch_size)
        if not 1 <= k <= self.batch_size:
            raise ValueError(
                'int(sparsity * batch_size) must be between 1 and batch_size, '
                'got k=%d (sparsity=%r, batch_size=%r)'
                % (k, self.sparsity, self.batch_size))

        with tf.variable_scope(self.name):
            # Step counter advanced by the optimizer op below. Non-trainable so
            # it is never treated as a model parameter; name and (default
            # float32) dtype are unchanged so existing checkpoints restore.
            self.global_step = tf.get_variable(
                'global_step',
                shape=[],
                initializer=tf.zeros_initializer(),
                trainable=False)
            # Batch dimension left open so self.encoded can be evaluated on any
            # number of rows; the sparsity mask (and hence self.decoded /
            # self.loss) requires exactly self.batch_size rows.
            self.input = tf.placeholder(tf.float32, shape=[None, self.input_dim])

        # All encoder layers but the last keep the input width; the last maps
        # to hidden_units. (Assumes _relu_layer(x, in_dim, out_dim, index) --
        # TODO confirm against its definition.)
        current = self.input
        for i in range(self.encode_layers - 1):
            current = self._relu_layer(current, self.input_dim, self.input_dim, i)
        self.encoded = self._relu_layer(
            current, self.input_dim, self.hidden_units, self.encode_layers - 1)

        # tf.nn.top_k works along the last axis, so make the batch the last
        # dimension: encoded_t is [hidden_units, batch].
        encoded_t = tf.transpose(self.encoded)

        # For each hidden unit, the batch positions of its k largest
        # activations: top_indices is [hidden_units, k].
        _, top_indices = tf.nn.top_k(encoded_t, k=k, sorted=False)

        # Pair each selected batch index with its hidden-unit row to form
        # [row, column] coordinates for tf.scatter_nd. Vectorized, so the
        # graph size does not grow with k.
        row_ids = tf.tile(tf.expand_dims(tf.range(self.hidden_units), 1), [1, k])
        indices = tf.reshape(tf.stack([row_ids, top_indices], axis=2), [-1, 2])

        # Apply sparsity constraint. top_k yields distinct columns per row, so
        # scatter_nd produces a 0/1 mask of shape [hidden_units, batch_size].
        updates = tf.ones(self.hidden_units * k)
        shape = tf.constant([self.hidden_units, self.batch_size])
        mask = tf.scatter_nd(indices, updates, shape)
        sparse_encoded = self.encoded * tf.transpose(mask)

        self.decoded = self._decode_layer(sparse_encoded)

        # Squared reconstruction error summed (not averaged) over all elements.
        self.loss = tf.reduce_sum(tf.square(self.decoded - self.input))
        # self.optimizer is presumably a tf.train.Optimizer subclass taking the
        # learning rate -- TODO confirm.
        self.optimizer_op = self.optimizer(self.learning_rate).minimize(
            self.loss, global_step=self.global_step)

        self.saver = tf.train.Saver(tf.global_variables())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号