def __init__(self, opt):
    """Build the encoders, the loss and the optimizer from ``opt``.

    Args:
        opt: Options object. The attributes read here are ``grad_clip``,
            ``data_name``, ``img_dim``, ``embed_size``, ``finetune``,
            ``cnn_type``, ``use_abs``, ``no_imgnorm``, ``vocab_size``,
            ``word_dim``, ``num_layers``, ``margin``, ``measure``,
            ``max_violation`` and ``learning_rate``.
    """
    # Original author's reference note: tutorials/09 - Image Captioning

    # Stored for later use; it is not consumed inside this method.
    self.grad_clip = opt.grad_clip

    # Build models.
    self.img_enc = EncoderImage(
        opt.data_name,
        opt.img_dim,
        opt.embed_size,
        opt.finetune,
        opt.cnn_type,
        use_abs=opt.use_abs,
        no_imgnorm=opt.no_imgnorm,
    )
    self.txt_enc = EncoderText(
        opt.vocab_size,
        opt.word_dim,
        opt.embed_size,
        opt.num_layers,
        use_abs=opt.use_abs,
    )

    # Move both encoders to the GPU only when one is available, so the
    # constructor also works on CPU-only machines.
    if torch.cuda.is_available():
        self.img_enc.cuda()
        self.txt_enc.cuda()
        # NOTE(review): kept inside the CUDA branch because the flag only
        # affects cuDNN; the flattened original did not show its nesting.
        cudnn.benchmark = True

    # Loss and optimizer.
    self.criterion = ContrastiveLoss(
        margin=opt.margin,
        measure=opt.measure,
        max_violation=opt.max_violation,
    )

    # Trainable parameters: the text encoder and the image encoder's `fc`
    # are always included; the image encoder's `cnn` only when fine-tuning.
    params = list(self.txt_enc.parameters())
    params.extend(self.img_enc.fc.parameters())
    if opt.finetune:
        params.extend(self.img_enc.cnn.parameters())

    # The same list object is kept on the instance and given to Adam.
    self.params = params
    self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)

    # Counter initialised to zero (presumably training iterations; confirm
    # against the training loop).
    self.Eiters = 0
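# Non-code labels (apparently web-page navigation text, not part of the program):
# Comment list
# Article table of contents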