# utils.py
import torch
import torch.optim as optim

# sleep_wake (wake-sleep update step) and ais_dbn (AIS log-likelihood
# estimation) are assumed to be defined or imported elsewhere in this file.
def generative_fine_tune(dbn, lr=1e-2, epoch=100, batch_size=50, input_data=None,
                         CD_k=1, optimization_method="Adam", momentum=0,
                         weight_decay=0, test_input=None):
    if optimization_method == "RMSprop":
        optimizer = optim.RMSprop(dbn.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif optimization_method == "SGD":
        optimizer = optim.SGD(dbn.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif optimization_method == "Adam":
        optimizer = optim.Adam(dbn.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError("unsupported optimization_method: %s" % optimization_method)
    # Dummy backward pass so that every parameter gets a .grad buffer
    # before sleep_wake manipulates gradients directly.
    for p in dbn.parameters():
        p.mean().backward()
    # Fine-tuning is unsupervised; the zero tensor is only a dummy target
    # required by TensorDataset.
    train_set = torch.utils.data.TensorDataset(input_data, torch.zeros(input_data.size()[0]))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    for i in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            sleep_wake(dbn=dbn, optimizer=optimizer, lr=lr, CD_k=CD_k, v=data, batch_size=batch_size)
        if test_input is not None:
            # Report an AIS estimate of the test log-likelihood each epoch.
            print("fine tune", i, ais_dbn.logp_ais(dbn, test_input, step=1000, M_Z=20, M_IS=100, parallel=True))