def create_data(self):
    """Build the train/test input bundles and select one at run time.

    Two bundles are created on the CPU inside the ``'input'`` variable
    scope. Each bundle is a ``[vertices, adjacency, labels, mask]`` list
    passed through ``self.create_input_variable``:

    * training: vertices, labels and both vertex axes of the adjacency
      restricted to ``self.train_idx``, with an all-ones float32 mask;
    * test: the full graph arrays, with a float32 mask of length
      ``self.largest_graph`` that is 1 only at ``self.test_idx``.

    Returns:
        ``tf.cond`` on ``self.net.is_training``: the training bundle when
        true, otherwise the test bundle.
    """
    with tf.device("/cpu:0"), tf.variable_scope('input'):
        self.print_ext('Creating training Tensorflow Tensors')

        # Training split: restrict every vertex axis to the training indices.
        # NOTE(review): the indexing implies vertices are rank 3 (vertex on
        # axis 1), adjacency rank 4 (vertex on axes 1 and 3) and labels
        # rank 2 (vertex on axis 1); other axis meanings are unconfirmed.
        train_idx = self.train_idx
        train_vertices = self.graph_vertices[:, train_idx, :]
        train_adjacency = self.graph_adjacency[:, train_idx, :, :][:, :, :, train_idx]
        train_labels = self.graph_labels[:, train_idx]
        train_mask = np.ones([1, len(train_idx), 1], dtype=np.float32)
        train_input = self.create_input_variable(
            [train_vertices, train_adjacency, train_labels, train_mask])

        # Test split: the whole graph is fed in; the mask marks which
        # vertices belong to the test set.
        # NOTE(review): assumes self.largest_graph matches the vertex-axis
        # size of the graph arrays — confirm against the caller.
        test_mask = np.zeros([1, self.largest_graph, 1], dtype=np.float32)
        test_mask[:, self.test_idx, :] = 1
        test_input = self.create_input_variable(
            [self.graph_vertices, self.graph_adjacency, self.graph_labels, test_mask])

        return tf.cond(self.net.is_training,
                       lambda: train_input,
                       lambda: test_input)
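# (scraped web-page residue, not code) 评论列表 — "Comment list"
# (scraped web-page residue, not code) 文章目录 — "Table of contents"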