neural_fingerprints.py source code


Project: neural_fingerprints_tf    Author: fllinares
```python
import tensorflow as tf  # TensorFlow 1.x: relies on tf.variable_scope and tf.contrib.layers


# Method of the model class: self.input, self.is_training, self.getitem and the
# *_graph_conv attributes are defined elsewhere in the project.
def graph_convolution_layer(self, node_emb, scope, edge_emb=None):
    # Path to hyperparameters and configuration settings for the graph convolutional layers
    prefix = 'model/graph_conv_layers'

    with tf.variable_scope(scope, reuse=not self.is_training):
        # Compute the extended node embedding as the concatenation of the original node embedding and the sum of
        # the node embeddings of all distance-one neighbors in the graph (sparse adjacency matrix times embeddings).
        ext_node_emb = tf.concat([node_emb, tf.sparse_tensor_dense_matmul(self.input['adj_mat'], node_emb)], axis=1)
        # If the model uses edge labels, also concatenate the (precomputed) sum of the feature vectors
        # labelling all edges connected to each node.
        if edge_emb is not None:
            ext_node_emb = tf.concat([ext_node_emb, edge_emb], axis=1)

        # Compute the output by applying a fully connected layer to the extended node embedding
        out = tf.contrib.layers.fully_connected(inputs=ext_node_emb,
                                                num_outputs=self.getitem('config', 'num_outputs', prefix),
                                                activation_fn=self.string_to_tf_act(self.getitem('config', 'activation_fn', prefix)),
                                                weights_initializer=self.weights_initializer_graph_conv,
                                                weights_regularizer=self.weights_regularizer_graph_conv,
                                                biases_initializer=tf.constant_initializer(0.1, tf.float32),
                                                normalizer_fn=self.normalizer_fn_graph_conv,
                                                normalizer_params=self.normalizer_params_graph_conv,
                                                trainable=self.getitem('config', 'trainable', prefix))

        # Apply dropout only during training. Alternatively, keep_prob could have been forced to 1.0
        # when is_training is False. In TF 1.x the second argument of tf.nn.dropout is keep_prob.
        if self.is_training:
            out = tf.nn.dropout(out, self.getitem('config', 'keep_prob', prefix))

    return out
```
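To make the data flow of this layer concrete without TensorFlow, here is a minimal pure-Python sketch of the same forward pass. It assumes a 0/1 adjacency matrix without self-loops (stored here as a neighbor-list dict, so the sparse matmul becomes a sum over neighbors) and a ReLU activation. It also leaves out the normalizer and weight regularizer that the original configures. All function names and parameters below are hypothetical illustrations, not part of the neural_fingerprints_tf project.

```python
import random
from typing import Callable, Dict, List, Optional

Vector = List[float]


def extended_node_embedding(node_emb: List[Vector],
                            neighbors: Dict[int, List[int]],
                            edge_emb: Optional[List[Vector]] = None) -> List[Vector]:
    """Concatenate each node's embedding with the sum of its neighbors' embeddings
    (the adjacency-matrix product), then optionally with a pre-summed edge feature vector."""
    dim = len(node_emb[0])
    ext = []
    for v, h_v in enumerate(node_emb):
        agg = [0.0] * dim
        for u in neighbors.get(v, []):
            for k in range(dim):
                agg[k] += node_emb[u][k]
        row = h_v + agg                      # list concatenation, like tf.concat(axis=1)
        if edge_emb is not None:
            row = row + edge_emb[v]
        ext.append(row)
    return ext


def dense(x: List[Vector], weights: List[Vector], bias: Vector,
          act: Callable[[float], float] = lambda z: max(z, 0.0)) -> List[Vector]:
    """Fully connected layer: act(x @ W + b), with W shaped [num_inputs][num_outputs]."""
    return [[act(sum(row[i] * weights[i][j] for i in range(len(row))) + bias[j])
             for j in range(len(bias))]
            for row in x]


def dropout(x: List[Vector], keep_prob: float, rng: random.Random) -> List[Vector]:
    """Inverted dropout: keep each unit with probability keep_prob and scale it by 1/keep_prob."""
    return [[z / keep_prob if rng.random() < keep_prob else 0.0 for z in row] for row in x]


def graph_conv_layer(node_emb: List[Vector], neighbors: Dict[int, List[int]],
                     weights: List[Vector], bias: Vector,
                     edge_emb: Optional[List[Vector]] = None,
                     is_training: bool = False, keep_prob: float = 0.8,
                     rng: Optional[random.Random] = None) -> List[Vector]:
    """One graph convolution: neighbor aggregation, dense layer, dropout only when training."""
    ext = extended_node_embedding(node_emb, neighbors, edge_emb)
    out = dense(ext, weights, bias)
    if is_training:
        out = dropout(out, keep_prob, rng or random.Random(0))
    return out
```

For example, on a three-node path graph `0-1-2` with `node_emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]`, node 1 gets the extended embedding `[0.0, 1.0, 2.0, 1.0]`: its own vector followed by the sum of the vectors of nodes 0 and 2. The input width of the dense layer is therefore `2 * node_dim` (plus the edge feature width when `edge_emb` is given), which is why the weight matrix must be sized for the concatenated input.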