def __init__(self, corpus, **opts):
    """Build the generator and discriminator graph for ``corpus``.

    Args:
        corpus: Container indexed with ``"train"`` and ``"valid"``. Each
            split is handed to ``get_corpus_size`` and ``prepare_data``;
            the train split is also handed to ``build_vocab``.
        **opts: Extra keyword options. Stored on ``self.opts`` and forwarded
            unchanged to ``prepare_data`` and to both the generator and
            discriminator templates.
    """
    self.corpus = corpus
    self.opts = opts

    # Global step and the op that advances it by one.
    self.global_step = get_or_create_global_step()
    self.increment_global_step_op = tf.assign(
        self.global_step,
        self.global_step + 1,
        name="increment_global_step",
    )

    # Corpus sizes for both splits; the vocabulary comes from train only.
    train_split = self.corpus["train"]
    self.corpus_size = get_corpus_size(train_split)
    valid_split = self.corpus["valid"]
    self.corpus_size_valid = get_corpus_size(valid_split)
    vocab = build_vocab(self.corpus["train"])
    self.word2idx, self.idx2word = vocab
    self.vocab_size = len(self.word2idx)

    # Templates so that repeated calls below go through the same
    # GENERATOR_PREFIX / DISCRIMINATOR_PREFIX names.
    self.generator_template = tf.make_template(GENERATOR_PREFIX, generator)
    self.discriminator_template = tf.make_template(
        DISCRIMINATOR_PREFIX, discriminator)

    # Training input pipeline (second returned value is not used here).
    train_inputs = prepare_data(
        self.corpus["train"], self.word2idx, num_threads=7, **self.opts)
    self.enqueue_data, _, src, tgt, seq_len = train_inputs

    # TODO: option to either do pretrain or just generate?
    pretrain = self.generator_template(
        src, tgt, seq_len, self.vocab_size, **self.opts)
    self.g_tensors_pretrain = pretrain

    # Validation input pipeline; this one also keeps the second returned
    # value as self.input_ph.
    valid_inputs = prepare_data(
        self.corpus["valid"], self.word2idx, num_threads=1, **self.opts)
    (self.enqueue_data_valid,
     self.input_ph,
     src_valid,
     tgt_valid,
     seq_len_valid) = valid_inputs
    self.g_tensors_pretrain_valid = self.generator_template(
        src_valid, tgt_valid, seq_len_valid, self.vocab_size, **self.opts)

    # Custom decoder built from the pretrain generator's embedding matrix
    # and output projections, using the *training* sequence lengths.
    self.decoder_fn = prepare_custom_decoder(
        seq_len,
        self.g_tensors_pretrain.embedding_matrix,
        self.g_tensors_pretrain.output_projections,
    )

    # Generator invoked again with the custom decoder ("fake" outputs),
    # for both the train and the validation inputs.
    self.g_tensors_fake = self.generator_template(
        src, tgt, seq_len, self.vocab_size,
        decoder_fn=self.decoder_fn, **self.opts)
    self.g_tensors_fake_valid = self.generator_template(
        src_valid, tgt_valid, seq_len_valid, self.vocab_size,
        decoder_fn=self.decoder_fn, **self.opts)

    # TODO: using the rnn outputs from pretraining as "real" instead of
    # target embeddings (aka professor forcing)
    real_outputs = self.g_tensors_pretrain.rnn_outputs
    self.d_tensors_real = self.discriminator_template(
        real_outputs, seq_len, is_real=True, **self.opts)

    # TODO: check to see if sequence_length is correct
    # (None is passed as the sequence length for the fake branch.)
    fake_outputs = self.g_tensors_fake.rnn_outputs
    self.d_tensors_fake = self.discriminator_template(
        fake_outputs, None, is_real=False, **self.opts)

    # Trainable variables, split by generator / discriminator scope.
    self.g_tvars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=GENERATOR_PREFIX)
    self.d_tvars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=DISCRIMINATOR_PREFIX)
# [Web page navigation residue, not part of the code: "Comment list", "Table of contents"]