reset_layers.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:conv_qsar_fast 作者: connorcoley 项目源码 文件源码
def reset(model):
    '''Given a Keras model consisting only of GraphFP, Dense, and Dropout layers,
    this function will reset the trainable weights to save time for CV tests.'''

    for layer in model.layers:
        # Note: these are custom depending on the layer type
        if '.GraphFP' in str(layer):
            W_inner = layer.init_inner((layer.inner_dim, layer.inner_dim))
            b_inner = np.zeros((1, layer.inner_dim))
            # Inner weights
            layer.W_inner.set_value((T.tile(W_inner, (layer.depth + 1, 1, 1)).eval() + \
                initializers.uniform((layer.depth + 1, layer.inner_dim, layer.inner_dim)).eval()).astype(np.float32))
            layer.b_inner.set_value((T.tile(b_inner, (layer.depth + 1, 1, 1)).eval()  + \
                initializers.uniform((layer.depth + 1, 1, layer.inner_dim)).eval()).astype(np.float32))

            # Outer weights
            W_output = layer.init_output((layer.inner_dim, layer.output_dim), scale = layer.scale_output)
            b_output = np.zeros((1, layer.output_dim))
            # Initialize weights tensor
            layer.W_output.set_value((T.tile(W_output, (layer.depth + 1, 1, 1)).eval()).astype(np.float32))
            layer.b_output.set_value((T.tile(b_output, (layer.depth + 1, 1, 1)).eval()).astype(np.float32))
            print('graphFP layer reset')

        elif '.Dense' in str(layer):
            layer.W.set_value((layer.init(layer.W.shape.eval()).eval()).astype(np.float32))
            layer.b.set_value(np.zeros(layer.b.shape.eval(), dtype=np.float32))
            print('dense layer reset')

        elif '.Dropout' in str(layer):
            print('dropout unchanged')
        else:
            print('Not reseting weights for {}'.format(str(layer)))
    print('Reset model weights')
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号