utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:restricted-boltzmann-machine-deep-belief-network-deep-boltzmann-machine-in-pytorch 作者: wmingwei 项目源码 文件源码
def generative_fine_tune(dbn, lr=1e-2, epoch=100, batch_size=50,
                         input_data=None, CD_k=1,
                         optimization_method="Adam", momentum=0,
                         weight_decay=0, test_input=None):
    """Fine-tune ``dbn`` generatively by running ``sleep_wake`` on mini-batches.

    Args:
        dbn: Model to fine-tune; must expose ``parameters()``.
        lr: Learning rate given to the optimizer and forwarded to
            ``sleep_wake``.
        epoch: Number of passes over ``input_data``.
        batch_size: Mini-batch size for the data loader; also forwarded to
            ``sleep_wake``.
        input_data: Training tensor whose first dimension indexes samples.
            Required despite the ``None`` default.
        CD_k: Forwarded unchanged to ``sleep_wake``.
        optimization_method: One of ``"RMSprop"``, ``"SGD"`` or ``"Adam"``.
        momentum: Momentum for RMSprop/SGD; not passed to Adam.
        weight_decay: Weight decay given to the optimizer.
        test_input: Optional tensor; when given, the result of
            ``ais_dbn.logp_ais`` on it is printed after every epoch.

    Raises:
        ValueError: If ``input_data`` is ``None`` or ``optimization_method``
            is not one of the supported names.
    """
    if input_data is None:
        raise ValueError("input_data must be provided")

    if optimization_method == "RMSprop":
        optimizer = optim.RMSprop(dbn.parameters(), lr=lr, momentum=momentum,
                                  weight_decay=weight_decay)
    elif optimization_method == "SGD":
        optimizer = optim.SGD(dbn.parameters(), lr=lr, momentum=momentum,
                              weight_decay=weight_decay)
    elif optimization_method == "Adam":
        optimizer = optim.Adam(dbn.parameters(), lr=lr,
                               weight_decay=weight_decay)
    else:
        # Fail here instead of with a NameError on `optimizer` further down.
        raise ValueError(
            "unsupported optimization_method: %r "
            "(expected 'RMSprop', 'SGD' or 'Adam')" % (optimization_method,)
        )

    # Back-propagating each parameter's mean populates its `.grad` tensor.
    # NOTE(review): presumably sleep_wake writes gradients into `.grad`
    # directly and needs them to exist beforehand -- confirm.
    for param in dbn.parameters():
        param.mean().backward()

    # TensorDataset needs a second tensor; the zero targets are placeholders
    # and are never read in the training loop.
    num_samples = input_data.size()[0]
    train_set = torch.utils.data.dataset.TensorDataset(
        input_data, torch.zeros(num_samples))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True)

    for epoch_idx in range(epoch):
        for data, _ in train_loader:
            # NOTE(review): the last batch can hold fewer than `batch_size`
            # rows while `batch_size` is still passed as-is -- verify that
            # sleep_wake tolerates this.
            sleep_wake(dbn=dbn, optimizer=optimizer, lr=lr, CD_k=CD_k,
                       v=data, batch_size=batch_size)

        if test_input is not None:
            # The model is passed as `dbn`; there is no `self` in this
            # module-level function.
            print("fine tune", epoch_idx,
                  ais_dbn.logp_ais(dbn, test_input, step=1000, M_Z=20,
                                   M_IS=100, parallel=True))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号