graph.py source file

python

Project: keras-gcn  Author: tkipf
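# Imports and class header reconstructed from the names used below;
# the `Layer` base class is an assumption (the excerpt omits it).
from keras import activations, initializers, regularizers
from keras.layers import Layer


class GraphConvolution(Layer):
    def __init__(self, output_dim, support=1, init='glorot_uniform',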
                 activation='linear', weights=None, W_regularizer=None,
                 b_regularizer=None, bias=False, **kwargs):
        self.init = initializers.get(init)
        self.activation = activations.get(activation)
        self.output_dim = output_dim  # number of features per node
        self.support = support  # filter support / number of weights

        assert support >= 1

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.bias = bias
        self.initial_weights = weights

        # these will be defined during build()
        self.input_dim = None
        self.W = None
        self.b = None

        super(GraphConvolution, self).__init__(**kwargs)

    # def get_output_shape_for(self, input_shapes):
    #     features_shape = input_shapes[0]
    #     output_shape = (features_shape[0], self.output_dim)
    #     return output_shape  # (batch_size, output_dim)
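
The constructor above only stores configuration; the weights themselves are created later in build(). As a rough illustration of what `support` and `output_dim` might control, here is a dependency-free sketch. It assumes the layer multiplies the node features by one propagation matrix per support, concatenates the results along the feature axis, and projects them to `output_dim` columns. The helper names `matmul` and `graph_conv` and the weight layout are hypothetical and not taken from the source file.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def graph_conv(features, bases, weights):
    """Hypothetical forward pass: concat_i(A_i @ X) @ W.

    features: N x F node features
    bases:    list of `support` matrices, each N x N
    weights:  (support * F) x output_dim
    """
    assert len(bases) >= 1  # mirrors `assert support >= 1` in __init__
    propagated = [matmul(a, features) for a in bases]
    # Concatenate the per-basis results along the feature axis.
    stacked = [sum((p[n] for p in propagated), [])
               for n in range(len(features))]
    return matmul(stacked, weights)


# Two nodes, one input feature, support=1, output_dim=2.
X = [[1.0], [2.0]]
A = [[0.5, 0.5], [0.5, 0.5]]  # toy normalised adjacency matrix
W = [[1.0, -1.0]]             # shape (support * F, output_dim) = (1, 2)
out = graph_conv(X, [A], W)
```

Under these assumptions each node ends up with `output_dim` values, which matches the `(batch_size, output_dim)` shape noted in the commented-out get_output_shape_for.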