test_recurrent.py source code

Project: keras-customized, author: ambrite

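```python
import numpy as np
import pytest
from keras import backend as K
from keras import regularizers
from keras.layers import recurrent

# Test sizes: module-level globals in the full file; these values are assumed.
nb_samples, timesteps, embedding_dim, output_dim = 2, 5, 4, 3

@pytest.mark.parametrize('layer_class',
                         [recurrent.SimpleRNN, recurrent.GRU, recurrent.LSTM])
def test_regularizer(layer_class):
    # Keras 1.x API: L1 on input (W) and recurrent (U) kernels, L2 on biases.
    layer = layer_class(output_dim, return_sequences=False, weights=None,
                        batch_input_shape=(nb_samples, timesteps, embedding_dim),
                        W_regularizer=regularizers.WeightRegularizer(l1=0.01),
                        U_regularizer=regularizers.WeightRegularizer(l1=0.01),
                        b_regularizer='l2')
    shape = (nb_samples, timesteps, embedding_dim)
    layer.build(shape)
    output = layer(K.variable(np.ones(shape)))
    K.eval(output)
    # One loss term per regularized W/U/b tensor: SimpleRNN has a single
    # set (3), GRU splits them over 3 blocks (9), LSTM over 4 blocks (12).
    if layer_class == recurrent.SimpleRNN:
        assert len(layer.losses) == 3
    elif layer_class == recurrent.GRU:
        assert len(layer.losses) == 9
    elif layer_class == recurrent.LSTM:
        assert len(layer.losses) == 12
```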
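To make the expected counts concrete, here is a minimal NumPy sketch of what the test is counting, without Keras: every regularized tensor contributes one penalty term, L1 for `W` and `U`, L2 for `b`. The helper names, the `BLOCKS` table, the random weights, and the 0.01 factor assumed behind the `'l2'` string are all illustrative assumptions, not the Keras implementation.

```python
import numpy as np

def l1_penalty(w, l1=0.01):
    # L1 term: l1 * sum(|w|), mirroring WeightRegularizer(l1=0.01)
    return l1 * np.sum(np.abs(w))

def l2_penalty(w, l2=0.01):
    # L2 term: l2 * sum(w**2); 0.01 is an assumed default for 'l2'
    return l2 * np.sum(np.square(w))

# Hypothetical: number of separate W/U/b blocks per layer type
BLOCKS = {'SimpleRNN': 1, 'GRU': 3, 'LSTM': 4}

def regularization_losses(kind, input_dim, units, rng=None):
    """Return one penalty term per regularized tensor (hypothetical helper)."""
    if rng is None:
        rng = np.random.default_rng(0)
    losses = []
    for _ in range(BLOCKS[kind]):
        W = rng.standard_normal((input_dim, units))  # input kernel
        U = rng.standard_normal((units, units))      # recurrent kernel
        b = np.zeros(units)                          # bias
        losses += [l1_penalty(W), l1_penalty(U), l2_penalty(b)]
    return losses

for kind in BLOCKS:
    terms = regularization_losses(kind, input_dim=4, units=3)
    print(kind, len(terms), float(sum(terms)))
```

The total regularization loss added to the training objective is the sum of these terms; the test only checks how many terms each layer type registers.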