# utils.py: joint training of a deep Boltzmann machine (DBM); assumes an
# energy() helper is defined elsewhere in this module.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

def joint_train(dbm, lr=1e-3, epoch=100, batch_size=50, input_data=None,
                weight_decay=0, k_positive=10, k_negative=10, alpha=(1e-1, 1e-1, 1)):
    # Learnable targets for the per-row norms of W[0] and W[1].
    u1 = nn.Parameter(torch.zeros(1))
    u2 = nn.Parameter(torch.zeros(1))
    # optimizer = optim.Adam(dbm.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer = optim.SGD(dbm.parameters(), lr=lr, momentum=0.5)
    train_set = torch.utils.data.TensorDataset(input_data, torch.zeros(input_data.size()[0]))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # u1 and u2 get their own optimizer with a much smaller learning rate.
    optimizer_u = optim.Adam([u1, u2], lr=lr / 1000, weight_decay=weight_decay)
    for ep in range(epoch):
        print("training epoch %i with u1 = %.4f, u2 = %.4f"
              % (ep, u1.data.numpy()[0], u2.data.numpy()[0]))
        for batch_idx, (data, target) in enumerate(train_loader):
            data = Variable(data)
            positive_phase, negative_phase = dbm(v_input=data, k_positive=k_positive,
                                                 k_negative=k_negative, greedy=False)
            # Energy gap between the data-driven (positive) and model-driven
            # (negative) phases, plus penalties that pull each row norm of W[0]
            # toward u1, each row norm of W[1] toward u2, and u1 toward u2.
            loss = (energy(dbm=dbm, layer=positive_phase)
                    - energy(dbm=dbm, layer=negative_phase)
                    + alpha[0] * torch.norm(torch.norm(dbm.W[0], 2, 1) - u1) ** 2
                    + alpha[1] * torch.norm(torch.norm(dbm.W[1], 2, 1) - u2) ** 2
                    + alpha[2] * (u1 - u2) ** 2)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            optimizer_u.step()
            optimizer_u.zero_grad()
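
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of what the alpha-weighted penalty terms in
# joint_train compute, with random matrices standing in for dbm.W[0] and
# dbm.W[1] (the shapes below are arbitrary assumptions, not the real model's).
# Each term measures how far the per-row L2 norms of a weight matrix sit from
# a shared learnable scalar, and the last term ties the two scalars together.
def _norm_penalty_sketch():
    W0 = nn.Parameter(torch.randn(500, 784))  # stand-in for dbm.W[0]
    W1 = nn.Parameter(torch.randn(500, 500))  # stand-in for dbm.W[1]
    u1 = nn.Parameter(torch.zeros(1))
    u2 = nn.Parameter(torch.zeros(1))
    alpha = (1e-1, 1e-1, 1)
    row_norms_0 = torch.norm(W0, 2, 1)  # L2 norm of each row of W0, shape (500,)
    row_norms_1 = torch.norm(W1, 2, 1)  # L2 norm of each row of W1, shape (500,)
    penalty = (alpha[0] * torch.norm(row_norms_0 - u1) ** 2    # rows of W0 near u1
               + alpha[1] * torch.norm(row_norms_1 - u2) ** 2  # rows of W1 near u2
               + alpha[2] * (u1 - u2) ** 2)                    # u1 near u2
    return penalty
# Because optimizer_u runs at lr/1000, u1 and u2 move far more slowly than the
# weights, so the row norms are regularized toward quasi-static shared targets.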