test_highway_network_helper.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pepnet 作者: hammerlab 项目源码 文件源码
def test_highway_layers():
    n_highway_layers = 5
    x = Input(shape=(8,), dtype="int32")
    v = Embedding(input_dim=2, output_dim=10)(x)
    v = Flatten()(v)
    assert hasattr(v, "_keras_shape")
    v = highway_layers(v, n_layers=n_highway_layers)
    output = Dense(1)(v)
    model = Model(inputs=[x], outputs=[output])
    assert len(model.layers) > n_highway_layers * 3
    x = np.array([
        [1] + [0] * 7,
        [0] * 8,
        [0] * 7 + [1]])
    y = np.array([0, 1, 0])
    model.compile("rmsprop", "mse")
    model.fit(x, y, epochs=10)
    pred = model.predict(x)
    mean_diff = np.abs(pred - y).mean()
    assert mean_diff < 0.5, pred
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号