def __init__(self, model, *args, **kwargs):
    """Build the network: persist the config, then construct the model, vocabs, datasets, and ops.

    Args:
      model: model class; instantiated with the config and the global step.
      *args: at most one positional argument, forwarded to the superclass.
      **kwargs: forwarded to the superclass; 'name' defaults to model.__name__.

    Raises:
      TypeError: if more than one positional argument is given.
    """
    # Original guarded with a redundant `if args:` wrapper; len() alone suffices.
    if len(args) > 1:
        raise TypeError('Parser takes at most one argument')

    kwargs['name'] = kwargs.pop('name', model.__name__)
    super(Network, self).__init__(*args, **kwargs)

    # Persist the effective configuration next to the saved model artifacts.
    if not os.path.isdir(self.save_dir):
        os.mkdir(self.save_dir)
    with open(os.path.join(self.save_dir, 'config.cfg'), 'w') as f:
        self._config.write(f)

    self._global_step = tf.Variable(0., trainable=False)
    self._global_epoch = tf.Variable(0., trainable=False)
    self._model = model(self._config, global_step=self.global_step)

    # Vocab spec: (file, CoNLL column index or indices, display name).
    # Only the word vocab (i == 0) honors self.cased and uses pretrained
    # embeddings; tag/rel vocabs are always cased and never pretrained.
    self._vocabs = []
    vocab_files = [(self.word_file, 1, 'Words'),
                   (self.tag_file, [3, 4], 'Tags'),
                   (self.rel_file, 7, 'Rels')]
    for i, (vocab_file, index, name) in enumerate(vocab_files):
        vocab = Vocab(vocab_file, index, self._config,
                      name=name,
                      cased=self.cased if not i else True,
                      use_pretrained=(not i),
                      global_step=self.global_step)
        self._vocabs.append(vocab)

    self._trainset = Dataset(self.train_file, self._vocabs, model, self._config, name='Trainset')
    self._validset = Dataset(self.valid_file, self._vocabs, model, self._config, name='Validset')
    self._testset = Dataset(self.test_file, self._vocabs, model, self._config, name='Testset')

    self._ops = self._gen_ops()
    # Materialize into a list: under Python 3, a bare filter() is a one-shot
    # iterator that would be silently exhausted by its first consumer
    # (e.g. the first tf.train.Saver built from it). Pretrained-embedding
    # variables are excluded from checkpointing.
    # NOTE(review): tf.all_variables() is deprecated in favor of
    # tf.global_variables() — keep as-is to match the file's TF version.
    self._save_vars = [x for x in tf.all_variables() if u'Pretrained' not in x.name]
    self.history = {
        'train_loss': [],
        'train_accuracy': [],
        'valid_loss': [],
        'valid_accuracy': [],
        # NOTE(review): key misspelled ('test_acuracy'); kept verbatim because
        # unseen code elsewhere may read this key — fix in tandem with all readers.
        'test_acuracy': 0
    }
#=============================================================
# 评论列表 (comment list) — non-code residue from the web page this file was scraped from
# 文章目录 (article table of contents) — non-code residue; commented out so the file stays parseable