def __init__(self, network, dictionary=None, seq_maxlen=25,
             clip_gradients=0.0, tensorboard_verbose=0,
             tensorboard_dir="/tmp/tflearn_logs/",
             checkpoint_path=None, max_checkpoints=None,
             session=None):
    """Attach a trainer and a predictor to an already-built network.

    Args:
        network: The network output; must be a `tf.Tensor` (asserted).
        dictionary: Stored as `self.dic` and handed to
            `reverse_dictionary` to build `self.rev_dic`.
            NOTE(review): the default `None` is forwarded to
            `reverse_dictionary` as-is -- confirm that helper accepts it.
        seq_maxlen: Stored as `self.seq_maxlen`; not used further here.
        clip_gradients: Forwarded unchanged to `Trainer`.
        tensorboard_verbose: Forwarded unchanged to `Trainer`.
        tensorboard_dir: Forwarded unchanged to `Trainer`.
        checkpoint_path: Forwarded unchanged to `Trainer`.
        max_checkpoints: Forwarded unchanged to `Trainer`.
        session: Forwarded unchanged to `Trainer`; the session exposed by
            the resulting trainer is the one reused for the predictor.
    """
    assert isinstance(network, tf.Tensor), "'network' arg is not a Tensor!"
    self.net = network

    # Training ops are read from the graph collection before the Trainer
    # is built, exactly as they are handed to it.
    graph_keys = tf.GraphKeys
    self.train_ops = tf.get_collection(graph_keys.TRAIN_OPS)

    trainer_options = dict(
        clip_gradients=clip_gradients,
        tensorboard_dir=tensorboard_dir,
        tensorboard_verbose=tensorboard_verbose,
        checkpoint_path=checkpoint_path,
        max_checkpoints=max_checkpoints,
        session=session,
    )
    trainer = Trainer(self.train_ops, **trainer_options)
    self.trainer = trainer
    # Share the trainer's session so training and prediction run in the
    # same session object.
    self.session = trainer.session

    # Input/target collections are read only after the Trainer exists,
    # preserving the original ordering.
    self.inputs = tf.get_collection(graph_keys.INPUTS)
    self.targets = tf.get_collection(graph_keys.TARGETS)

    self.predictor = Evaluator([network], session=self.session)

    self.dic = dictionary
    self.rev_dic = reverse_dictionary(dictionary)
    self.seq_maxlen = seq_maxlen
# (web-page residue, not code) Comment list
# (web-page residue, not code) Article table of contents