experiment.py source code

python

Project: Graph-CNN  Author: fps7806
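```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, graph-mode tf.cond)


def create_data(self):
    # Method of the experiment class: builds the input tensors for a single
    # large graph, switching between a training subgraph and the full graph.
    with tf.device("/cpu:0"):
        with tf.variable_scope('input') as scope:
            self.print_ext('Creating training Tensorflow Tensors')

            # Training input: the subgraph induced by the training nodes.
            vertices = self.graph_vertices[:, self.train_idx, :]
            # Select rows, then columns, in two separate steps to get the full
            # train x train block of the adjacency tensor.
            adjacency = self.graph_adjacency[:, self.train_idx, :, :]
            adjacency = adjacency[:, :, :, self.train_idx]
            labels = self.graph_labels[:, self.train_idx]
            # Mask of ones: every node in the training subgraph is active.
            input_mask = np.ones([1, len(self.train_idx), 1]).astype(np.float32)

            train_input = [vertices, adjacency, labels, input_mask]
            train_input = self.create_input_variable(train_input)

            # Test input: the whole graph, with a mask selecting only the test nodes.
            vertices = self.graph_vertices
            adjacency = self.graph_adjacency
            labels = self.graph_labels

            input_mask = np.zeros([1, self.largest_graph, 1]).astype(np.float32)
            input_mask[:, self.test_idx, :] = 1
            test_input = [vertices, adjacency, labels, input_mask]
            test_input = self.create_input_variable(test_input)

            # Pick the train or test input set at run time from the is_training flag.
            return tf.cond(self.net.is_training, lambda: train_input, lambda: test_input)
```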
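The split performed by `create_data` can be reproduced in plain NumPy, which makes the shapes easy to inspect without building a TensorFlow graph. This is a minimal sketch, not the Graph-CNN implementation: the function name `split_single_graph` is hypothetical, the array shapes (`[1, N, F]` vertices, `[1, N, L, N]` adjacency, `[1, N]` labels) are assumptions inferred from the indexing in the method, and it assumes `largest_graph` equals the node count N.

```python
import numpy as np


def split_single_graph(vertices, adjacency, labels, train_idx, test_idx):
    """Hypothetical helper mirroring the train/test split in create_data.

    Assumed shapes (inferred from the indexing in create_data):
      vertices:  [1, N, F]     node features
      adjacency: [1, N, L, N]  one N x N slice per edge type
      labels:    [1, N]        per-node class ids
    """
    train_idx = np.asarray(train_idx)
    n_nodes = vertices.shape[1]  # assumed to play the role of largest_graph

    # Training input: the subgraph induced by the training nodes.
    train_vertices = vertices[:, train_idx, :]
    # Rows first, then columns. A single adjacency[:, train_idx, :, train_idx]
    # would pair the two index arrays element-wise instead of selecting the
    # full train x train block.
    train_adjacency = adjacency[:, train_idx, :, :][:, :, :, train_idx]
    train_labels = labels[:, train_idx]
    train_mask = np.ones([1, len(train_idx), 1], dtype=np.float32)

    # Test input: the full graph, with a mask that selects only the test nodes.
    test_mask = np.zeros([1, n_nodes, 1], dtype=np.float32)
    test_mask[:, test_idx, :] = 1

    train = (train_vertices, train_adjacency, train_labels, train_mask)
    test = (vertices, adjacency, labels, test_mask)
    return train, test
```

For example, with N = 5 nodes, L = 2 edge types, training nodes `[0, 2, 3]` and test nodes `[1, 4]`, the training adjacency has shape `(1, 3, 2, 3)` and the flattened test mask is `[0, 1, 0, 0, 1]`. The two-step indexing matters: `adjacency[:, [0, 2, 3], :, [0, 2, 3]]` returns shape `(3, 1, 2)`, i.e. only the pairs (0, 0), (2, 2), (3, 3) rather than a 3 x 3 block. Keeping the full graph at test time presumably lets test nodes draw on all their neighbours, while the mask restricts which nodes are evaluated.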