ae.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:workspace 作者: nojima 项目源码 文件源码
def train(dataset: DataSet, n_iter: int = 3000, batch_size: int = 25) -> Iterator[AutoEncoder]:
    n = dataset.size

    input_dimension = dataset.input.shape[1]
    hidden_dimension = 2
    model = AutoEncoder(input_dimension, hidden_dimension)

    optimizer = optimizers.Adam()
    optimizer.setup(model)

    for j in range(n_iter):
        shuffled = np.random.permutation(n)

        for i in range(0, n, batch_size):
            indices = shuffled[i:i+batch_size]
            x = Variable(dataset.input[indices])
            model.cleargrads()
            loss = model(x)
            loss.backward()
            optimizer.update()

        yield model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号