import os
import shutil

import torch
import torch.optim as optim


def train_epochs(model, loss_fn, init_lr, model_dir):
    # Start from a clean checkpoint directory
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)
    os.makedirs(model_dir)

    optimizer = optim.Adam(model.parameters(), lr=init_lr)  # set up the optimizer
    learning_rate = init_lr
    max_iter = 5             # total number of training epochs
    start_halving_iter = 2   # epoch at which learning-rate decay begins
    halving_factor = 0.1     # multiplicative decay factor
    count = 0
    half_flag = False

    while count < max_iter:
        count += 1
        if count >= start_halving_iter:
            half_flag = True

        print("Starting epoch", count)

        if half_flag:
            learning_rate *= halving_factor
            adjust_learning_rate(optimizer, halving_factor)  # decay the learning rate

        model_path = model_dir + '/epoch' + str(count) + '_lr' + str(learning_rate) + '.pkl'
        train_one_epoch(model, loss_fn, optimizer)  # train for one epoch
        torch.save(model.state_dict(), model_path)  # checkpoint after each epoch

    print("End training")