utils.py source code (Python)


Project: restricted-boltzmann-machine-deep-belief-network-deep-boltzmann-machine-in-pytorch, author: wmingwei
```python
import torch
import torch.nn as nn
import torch.optim as optim


def joint_train(dbm, lr=1e-3, epoch=100, batch_size=50, input_data=None, weight_decay=0,
                k_positive=10, k_negative=10, alpha=(1e-1, 1e-1, 1)):
    # `energy` is assumed to be defined elsewhere in utils.py; `dbm.W` holds the weight matrices.
    u1 = nn.Parameter(torch.zeros(1))  # learnable target for the row norms of W[0]
    u2 = nn.Parameter(torch.zeros(1))  # learnable target for the row norms of W[1]
    # optimizer = optim.Adam(dbm.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer = optim.SGD(dbm.parameters(), lr=lr, momentum=0.5)
    optimizer_u = optim.Adam([u1, u2], lr=lr / 1000, weight_decay=weight_decay)
    # Targets are unused; zeros only let TensorDataset yield (data, target) pairs.
    train_set = torch.utils.data.TensorDataset(input_data, torch.zeros(input_data.size(0)))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    for ep in range(epoch):
        print("training epoch %i with u1 = %.4f, u2 = %.4f" % (ep, u1.item(), u2.item()))
        for batch_idx, (data, target) in enumerate(train_loader):
            positive_phase, negative_phase = dbm(v_input=data, k_positive=k_positive,
                                                 k_negative=k_negative, greedy=False)
            # Row norms have shape (rows,) and u1/u2 have shape (1,), so they broadcast
            # directly; repeating u to (rows, 1) would broadcast to (rows, rows).
            reg = (alpha[0] * torch.norm(torch.norm(dbm.W[0], 2, 1) - u1) ** 2
                   + alpha[1] * torch.norm(torch.norm(dbm.W[1], 2, 1) - u2) ** 2
                   + alpha[2] * (u1 - u2) ** 2)
            # Energy of the data-clamped phase minus energy of the free-running phase.
            loss = energy(dbm=dbm, layer=positive_phase) - energy(dbm=dbm, layer=negative_phase) + reg
            optimizer.zero_grad()
            optimizer_u.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer_u.step()
```
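The regularizer in `joint_train` pulls every row norm of `W[0]` toward a learnable scalar `u1`, every row norm of `W[1]` toward `u2`, and `u1` toward `u2`. The sketch below isolates that term so its shapes and gradients can be checked on toy matrices. The function name `norm_tying_penalty` and the example weights are hypothetical, and it uses `torch.linalg.vector_norm(W, ord=2, dim=1)` as an equivalent of the listing's row-wise `torch.norm(W, 2, 1)`.

```python
import torch


def norm_tying_penalty(W1, W2, u1, u2, alpha=(1e-1, 1e-1, 1.0)):
    """Hypothetical standalone version of the norm-tying regularizer in joint_train."""
    # Row-wise L2 norms have shape (rows,); u1 and u2 have shape (1,) and broadcast.
    r1 = torch.linalg.vector_norm(W1, ord=2, dim=1) - u1
    r2 = torch.linalg.vector_norm(W2, ord=2, dim=1) - u2
    # torch.norm(v) ** 2 in the listing equals the sum of squared entries of v.
    return (alpha[0] * r1.pow(2).sum()
            + alpha[1] * r2.pow(2).sum()
            + alpha[2] * (u1 - u2).pow(2).sum())


if __name__ == "__main__":
    W1 = torch.tensor([[3.0, 4.0], [0.0, 0.0]])  # row norms: 5 and 0
    W2 = torch.tensor([[1.0, 0.0]])              # row norm: 1
    u1 = torch.tensor([1.0], requires_grad=True)
    u2 = torch.tensor([0.0], requires_grad=True)

    p = norm_tying_penalty(W1, W2, u1, u2)
    p.backward()
    print(round(p.item(), 4), round(u1.grad.item(), 4))  # 2.8 1.4

    # Pitfall: repeating u1 to shape (rows, 1) broadcasts against (rows,).
    bad = torch.linalg.vector_norm(W1, ord=2, dim=1) - u1.repeat(W1.size(0), 1)
    print(tuple(bad.shape))  # (2, 2), not (2,)
```

With the repeated form, every residual `n_i - u` appears once per row, so the penalty is silently scaled by the number of rows. The listing's original `u.repeat(rows, 1)` likely relied on older PyTorch releases in which `torch.norm(W, 2, 1)` kept the reduced dimension.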