def _scalar(value):
    """Return the Python number held by a single-element tensor/Variable.

    PyTorch >= 0.4 provides ``.item()``; there, indexing a 0-dim tensor with
    ``.data[0]`` raises, so that form is only a fallback for older releases
    that lack ``.item()``.
    """
    if hasattr(value, 'item'):
        return value.item()
    return value.data[0]


def train(epoch, log_interval=500, ce_weight=1.0 / 300):
    """Run one training epoch of ``color_model`` over ``train_loader``.

    Relies on module-level names defined elsewhere in this file:
    ``color_model``, ``optimizer``, ``train_loader`` and ``have_cuda``.

    For every batch the total loss is appended to ``./message.txt``. Every
    ``log_interval`` batches a progress line is also written there and the
    model parameters are checkpointed to ``colornet_params.pkl``. Any
    exception raised during the epoch is caught, its traceback is written to
    ``log.txt`` (overwriting earlier content) and it is NOT re-raised; the
    parameters are saved once more when the function exits either way.

    Args:
        epoch: Epoch number; used only in the progress message.
        log_interval: Number of batches between progress messages and
            checkpoints (default 500, the previously hard-coded value).
        ce_weight: Weight of the classification cross-entropy term
            (default 1/300). Written as a float expression so it is not
            truncated to 0 by integer division under Python 2.
    """
    color_model.train()
    try:
        # Open the message log once per epoch; ``with`` closes it even if a
        # batch raises (the old per-batch open leaked the handle then).
        with open('./message.txt', 'a') as messagefile:
            for batch_idx, (data, classes) in enumerate(train_loader):
                # data[0] is the network input, data[1] the ab target, so
                # ``len(data)`` is the number of parts, not the batch size.
                original_img = data[0].unsqueeze(1).float()
                img_ab = data[1].float()
                batch_size = original_img.size(0)
                if have_cuda:
                    original_img = original_img.cuda()
                    img_ab = img_ab.cuda()
                    classes = classes.cuda()
                original_img = Variable(original_img)
                img_ab = Variable(img_ab)
                classes = Variable(classes)

                optimizer.zero_grad()
                class_output, output = color_model(original_img, original_img)
                # Sum of squared errors divided by the number of elements in
                # ``output``, i.e. the mean squared error.
                ems_loss = torch.pow(img_ab - output, 2).mean()
                cross_entropy_loss = ce_weight * F.cross_entropy(class_output, classes)
                loss = ems_loss + cross_entropy_loss
                loss_value = _scalar(loss)
                messagefile.write('loss: %.9f\n' % (loss_value))
                # Backpropagating the sum accumulates the same gradients as two
                # separate backward passes, without having to retain the graph.
                loss.backward()
                optimizer.step()

                if batch_idx % log_interval == 0:
                    message = 'Train Epoch:%d\tPercent:[%d/%d (%.0f%%)]\tLoss:%.9f\n' % (
                        epoch, batch_idx * batch_size, len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss_value)
                    messagefile.write(message)
                    torch.save(color_model.state_dict(), 'colornet_params.pkl')
                # Flush each batch so progress stays visible on disk, as the
                # former per-batch open/close did.
                messagefile.flush()
    except Exception:
        # Deliberate best-effort: record the traceback and return normally.
        with open('log.txt', 'w') as logfile:
            logfile.write(traceback.format_exc())
    finally:
        torch.save(color_model.state_dict(), 'colornet_params.pkl')
# (Stray web-page navigation residue from the original source page:
#  "Comment list" / "Table of contents".)