test_save_read.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:paysage 作者: drckf 项目源码 文件源码
def test_grbm_reload():
    """Check that a model survives a save/reload round trip through HDFStore.

    Builds a two-layer model (``BernoulliLayer`` visible, ``GaussianLayer``
    hidden), writes it with ``Model.save`` into an HDF5 file, reads it back
    with ``Model.from_saved``, and asserts that the original and reloaded
    models produce ``allclose`` results from the same input state.

    ``num_vis``, ``num_hid`` and ``num_samples`` are not defined in this
    function (presumably module-level constants of this test file).
    """
    import os  # local import: keeps this edit self-contained

    vis_layer = layers.BernoulliLayer(num_vis)
    hid_layer = layers.GaussianLayer(num_hid)
    grbm = model.Model([vis_layer, hid_layer])

    # Write into a temporary *directory* rather than reopening a
    # NamedTemporaryFile by name: reopening an still-open named temp file
    # is not supported on Windows, while a fresh path inside a temp dir
    # works everywhere and is removed with the directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'model.h5')
        # Context managers guarantee each store is closed even if
        # save/from_saved raises, so a failing test does not leak an
        # open HDF5 handle.
        with pandas.HDFStore(path, mode='w') as store:
            grbm.save(store)
        with pandas.HDFStore(path, mode='r') as store:
            grbm_reload = model.Model.from_saved(store)

    # Feed both models the same visible data and dropout scale, then compare
    # element 0 of `.units` from one deterministic_iteration call
    # (presumably the visible layer, given the variable names - confirm).
    vis_data = vis_layer.random((num_samples, num_vis))
    data_state = model_utils.State.from_visible(vis_data, grbm)
    dropout_scale = model_utils.State.dropout_rescale(grbm)
    vis_orig = grbm.deterministic_iteration(
        1, data_state, dropout_scale).units[0]
    vis_reload = grbm_reload.deterministic_iteration(
        1, data_state, dropout_scale).units[0]
    assert be.allclose(vis_orig, vis_reload)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号