testing_utils.py source code

python

Project: ntm_keras    Author: flomlo
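def train_model(model, epochs=10, min_size=5, max_size=20, callbacks=None, verboose=False):
    """Train `model` on batches from get_sample for `epochs` epochs of 10 steps each.

    With verboose=True, training runs one epoch per fit_generator call, and after
    every epoch the model is evaluated with test_model on sequence lengths 5, 10, 20 and 40.
    """
    # These attributes must be set on the model object beforehand;
    # a plain Keras Model does not define input_dim, output_dim or batch_size.
    input_dim = model.input_dim
    output_dim = model.output_dim
    batch_size = model.batch_size

    # Training batches whose sequence lengths are bounded by min_size and max_size.
    sample_generator = get_sample(batch_size=batch_size, in_bits=input_dim, out_bits=output_dim,
                                  max_size=max_size, min_size=min_size)
    if verboose:
        for j in range(epochs):
            # epochs=j+1 together with initial_epoch=j trains exactly one epoch.
            model.fit_generator(sample_generator, steps_per_epoch=10, epochs=j+1,
                                callbacks=callbacks, initial_epoch=j)
            print("currently at epoch {0}".format(j+1))
            # Evaluate on several lengths, including 40, which exceeds the default max_size of 20.
            for i in [5, 10, 20, 40]:
                test_model(model, sequence_length=i, verboose=True)
    else:
        model.fit_generator(sample_generator, steps_per_epoch=10, epochs=epochs, callbacks=callbacks)

    print("done training")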
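Both branches of `train_model` train for the same number of epochs. The verbose branch only splits training into one-epoch calls so it can run evaluations in between. This relies on the Keras convention that, when `initial_epoch` is given, `epochs` is the index of the final epoch, so a call trains epochs `initial_epoch` through `epochs - 1` (0-based). The sketch below checks that arithmetic without TensorFlow. `RecordingModel`, `dummy_batches`, `verbose_schedule` and `quiet_schedule` are hypothetical stand-ins, not part of ntm_keras or Keras.

```python
class RecordingModel:
    """Hypothetical stand-in for a Keras model: it only records which epoch
    indices fit_generator would run and how many batches it would draw."""

    def __init__(self):
        self.trained_epochs = []
        self.batches_seen = 0

    def fit_generator(self, generator, steps_per_epoch, epochs, callbacks=None, initial_epoch=0):
        # Assumed Keras semantics: train epochs initial_epoch .. epochs - 1.
        for epoch in range(initial_epoch, epochs):
            for _ in range(steps_per_epoch):
                next(generator)  # one batch per step
                self.batches_seen += 1
            self.trained_epochs.append(epoch)


def dummy_batches():
    # Hypothetical endless batch source; real code would yield (inputs, targets).
    while True:
        yield None, None


def verbose_schedule(model, epochs):
    """One fit_generator call per epoch, mirroring the verboose branch."""
    gen = dummy_batches()
    for j in range(epochs):
        model.fit_generator(gen, steps_per_epoch=10, epochs=j + 1, initial_epoch=j)
    return model.trained_epochs


def quiet_schedule(model, epochs):
    """A single fit_generator call, mirroring the else branch."""
    model.fit_generator(dummy_batches(), steps_per_epoch=10, epochs=epochs)
    return model.trained_epochs


print(verbose_schedule(RecordingModel(), 3))  # [0, 1, 2]
print(quiet_schedule(RecordingModel(), 3))    # [0, 1, 2]
```

Both schedules visit the same epoch indices and draw `10 * epochs` batches, so choosing `verboose=True` changes only when evaluation happens, not how much training is done.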
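`get_sample` is not shown in this excerpt. For `fit_generator`, it only has to be an endless generator yielding `(inputs, targets)` batches. Below is a minimal sketch of such a generator for an NTM-style copy task, accepting the same keyword arguments as the `get_sample` call above. The function name `copy_task_batches`, the `seed` parameter and the channel layout (a delimiter flag in one extra input channel, with the targets repeating the data after it) are assumptions and may differ from what ntm_keras actually does. Because the sequence length changes from batch to batch, a model fed this way needs a variable-length time dimension.

```python
import numpy as np


def copy_task_batches(batch_size, in_bits, out_bits, max_size, min_size, seed=None):
    """Hypothetical copy-task generator with get_sample's keyword arguments.

    Yields (inputs, targets) forever. Assumed layout: in_bits >= out_bits + 1,
    and input channel `out_bits` carries a delimiter flag after the data.
    """
    if in_bits < out_bits + 1:
        raise ValueError("need one spare input channel for the delimiter")
    rng = np.random.default_rng(seed)
    while True:
        length = int(rng.integers(min_size, max_size + 1))  # inclusive of max_size
        steps = 2 * length + 1  # data, delimiter, then recall phase
        data = rng.integers(0, 2, size=(batch_size, length, out_bits)).astype(np.float32)

        inputs = np.zeros((batch_size, steps, in_bits), dtype=np.float32)
        inputs[:, :length, :out_bits] = data  # present the random bit sequence
        inputs[:, length, out_bits] = 1.0     # delimiter marks the end of the input

        targets = np.zeros((batch_size, steps, out_bits), dtype=np.float32)
        targets[:, length + 1:, :] = data     # expect the copy after the delimiter
        yield inputs, targets


gen = copy_task_batches(batch_size=4, in_bits=10, out_bits=8, max_size=20, min_size=5, seed=0)
inputs, targets = next(gen)
print(inputs.shape[0], inputs.shape[2], targets.shape[2])  # 4 10 8
```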