trainer.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:R-net 作者: matthew-z 项目源码 文件源码
def __init__(self, args, dataloader_train, dataloader_dev, char_embedding_config, word_embedding_config,
                 sentence_encoding_config, pair_encoding_config, self_matching_config, pointer_config):
        """Build the R-net model and its optimizer, optionally resume from a checkpoint,
        and set up logging and the checkpoint directory.

        Args:
            args: parsed options. Read here: ``dev_json``, ``start_epoch``, ``name``,
                ``resume``, ``device_id`` and ``checkpoint_path``. The whole object is
                also forwarded to ``RNet.Model``.
            dataloader_train: training data loader; stored as-is.
            dataloader_dev: dev data loader; stored as-is.
            char_embedding_config, word_embedding_config, sentence_encoding_config,
            pair_encoding_config, self_matching_config, pointer_config:
                sub-module configurations forwarded unchanged to ``RNet.Model``.

        Raises:
            ValueError: if ``args.resume`` is set but does not name an existing file.
        """
        # Keep the raw dev set (its 'data' entry) for validation. A version mismatch
        # only prints a warning to stderr; loading continues regardless.
        expected_version = "1.1"
        # JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
        with open(args.dev_json, encoding="utf-8") as dataset_file:
            dataset_json = json.load(dataset_file)
        if dataset_json['version'] != expected_version:
            print('Evaluation expects v-' + expected_version +
                  ', but got dataset with v-' + dataset_json['version'],
                  file=sys.stderr)
        self.dev_dataset = dataset_json['data']

        self.dataloader_train = dataloader_train
        self.dataloader_dev = dataloader_dev

        use_cuda = torch.cuda.is_available()

        self.model = RNet.Model(args, char_embedding_config, word_embedding_config, sentence_encoding_config,
                                pair_encoding_config, self_matching_config, pointer_config)
        # Place the model on its final device BEFORE creating the optimizer and
        # before loading any optimizer state (as torch.optim recommends). The
        # optimizer and its state then refer to parameters on the right device.
        if use_cuda:
            self.model = self.model.cuda(args.device_id)
        else:
            self.model = self.model.cpu()

        # Only parameters with requires_grad are handed to the optimizer.
        self.parameters_trainable = list(
            filter(lambda p: p.requires_grad, self.model.parameters()))
        self.optimizer = optim.Adadelta(self.parameters_trainable, rho=0.95)
        self.best_f1 = 0
        self.step = 0
        self.start_epoch = args.start_epoch
        self.name = args.name
        self.start_time = datetime.datetime.now().strftime('%b-%d_%H-%M')

        if args.resume:
            if not os.path.isfile(args.resume):
                raise ValueError("=> no checkpoint found at '{}'".format(args.resume))

            print("=> loading checkpoint '{}'".format(args.resume))

            # Remap stored tensors onto this run's device. This lets a checkpoint
            # written on a GPU (or on a different GPU id) be resumed here,
            # including on a CPU-only machine.
            def _map_location(storage, loc):
                return storage.cuda(args.device_id) if use_cuda else storage

            checkpoint = torch.load(args.resume, map_location=_map_location)
            self.start_epoch = checkpoint['epoch']
            self.best_f1 = checkpoint['best_f1']
            # The resumed run keeps the checkpoint's name and start time, so logs
            # and checkpoints continue in the same directories.
            self.name = checkpoint['name']
            self.step = checkpoint['step']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.start_time = checkpoint['start_time']

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            # A fresh run gets a unique name by suffixing the start timestamp.
            self.name += "_" + self.start_time

        self.loss_fn = torch.nn.CrossEntropyLoss()

        configure("log/%s" % (self.name), flush_secs=5)
        self.checkpoint_path = os.path.join(args.checkpoint_path, self.name)
        make_dirs(self.checkpoint_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号