def construct_graph(self):
    """Build the network, its SGD optimizer, and the TensorBoard writers.

    Seeds torch's RNG with ``cfg.RNG_SEED``, creates the network architecture
    for ``self.imdb.num_classes`` classes (with ``cfg.ANCHOR_SCALES`` /
    ``cfg.ANCHOR_RATIOS``), then builds one SGD parameter group per trainable
    parameter so bias terms can get their own learning rate and weight decay.

    Bias parameters (any parameter whose name contains ``'bias'``):
      * learning rate is ``lr * (cfg.TRAIN.DOUBLE_BIAS + 1)`` -- i.e. doubled
        when ``DOUBLE_BIAS`` is True (presumably a bool flag; confirm);
      * weight decay is ``cfg.TRAIN.WEIGHT_DECAY`` when
        ``cfg.TRAIN.BIAS_DECAY`` is truthy, otherwise 0.
    All other trainable parameters use ``lr`` and ``cfg.TRAIN.WEIGHT_DECAY``.
    Parameters with ``requires_grad`` False are not given to the optimizer.

    Side effects: sets ``self.optimizer``, ``self.writer`` and
    ``self.valwriter``.

    Returns:
        tuple: ``(lr, self.optimizer)`` -- the base learning rate
        ``cfg.TRAIN.LEARNING_RATE`` and the ``torch.optim.SGD`` instance.
    """
    # Seed torch's RNG before the network is created.
    torch.manual_seed(cfg.RNG_SEED)

    # Build the main computation graph.
    self.net.create_architecture(self.imdb.num_classes, tag='default',
                                 anchor_scales=cfg.ANCHOR_SCALES,
                                 anchor_ratios=cfg.ANCHOR_RATIOS)

    # Base learning rate plus the bias-specific overrides (loop-invariant).
    lr = cfg.TRAIN.LEARNING_RATE
    bias_lr = lr * (cfg.TRAIN.DOUBLE_BIAS + 1)
    bias_decay = cfg.TRAIN.WEIGHT_DECAY if cfg.TRAIN.BIAS_DECAY else 0

    # One group per parameter so biases can use their own lr / weight decay.
    params = []
    for key, value in self.net.named_parameters():
        if not value.requires_grad:
            continue  # frozen parameters are not optimised
        if 'bias' in key:
            params.append({'params': [value], 'lr': bias_lr,
                           'weight_decay': bias_decay})
        else:
            params.append({'params': [value], 'lr': lr,
                           'weight_decay': cfg.TRAIN.WEIGHT_DECAY})
    self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)

    # Writers for the train and validation summaries.
    self.writer = tb.writer.FileWriter(self.tbdir)
    self.valwriter = tb.writer.FileWriter(self.tbvaldir)
    return lr, self.optimizer
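# Scraped web-page residue (not code): "Comment list"
# Scraped web-page residue (not code): "Table of contents"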