def __init__(self, saved_model_dir, input_schema, exclude_outputs,
             tf_config):
    """Load a saved transform graph into a dedicated session.

    Args:
      saved_model_dir: Directory of the saved transform to load.
      input_schema: Schema whose ``column_schemas`` keys select the graph
        inputs kept in ``self.inputs``; every key must be a graph input.
      exclude_outputs: Output keys to drop from ``self.outputs``; every key
        must be a graph output.
      tf_config: Config used to create ``self.session``.

    Raises:
      ValueError: if ``input_schema`` or ``exclude_outputs`` names a key that
        is not present in the loaded graph. The message lists only the
        offending keys.
    """
    self.saved_model_dir = saved_model_dir
    self.session = tf.Session(graph=tf.Graph(), config=tf_config)
    with self.session.graph.as_default():
        # Make the retained session the default one instead of creating, and
        # immediately discarding, a second session on the same graph.
        with self.session.as_default():
            inputs, outputs = saved_transform_io.partially_apply_saved_transform(
                saved_model_dir, {})
        # Initialize lookup tables in the session that is kept on self.
        self.session.run(tf.tables_initializer())

    input_schema_keys = input_schema.column_schemas.keys()
    extra_input_keys = set(input_schema_keys).difference(inputs.keys())
    if extra_input_keys:
        raise ValueError('Input schema contained keys not in graph: %s' %
                         sorted(extra_input_keys))

    extra_output_keys = set(exclude_outputs).difference(outputs.keys())
    if extra_output_keys:
        raise ValueError('Excluded outputs contained keys not in graph: %s' %
                         sorted(extra_output_keys))

    non_excluded_output_keys = set(outputs.keys()).difference(
        exclude_outputs)
    self.inputs = {key: inputs[key] for key in input_schema_keys}
    self.outputs = {key: outputs[key] for key in non_excluded_output_keys}
Python tables_initializer() example source code
def test_table_roundtrip(self):
    """A table lookup exported in a saved transform still works on re-import.

    The first graph maps a string placeholder through an index table backed
    by a constant vocabulary and exports it. The second graph applies the
    export to a constant input and checks that 'dog' maps to index 1.
    Sessions are entered with ``with`` so that each is closed when its block
    ends; ``Session.as_default()`` alone would leave them open.
    """
    export_path = os.path.join(tempfile.mkdtemp(), 'export')
    with tf.Graph().as_default():
        with tf.Session() as session:
            input_string = tf.placeholder(tf.string)
            # Map string through a table, in this case based on a constant tensor.
            table = lookup.index_table_from_tensor(
                tf.constant(['cat', 'dog', 'giraffe']))
            output = table.lookup(input_string)
            inputs = {'input': input_string}
            outputs = {'output': output}
            saved_transform_io.write_saved_transform_from_session(
                session, inputs, outputs, export_path)

    with tf.Graph().as_default():
        with tf.Session() as session:
            # Using a computed input gives confidence that the graphs are fused.
            input_string = tf.constant('dog')
            inputs = {'input': input_string}
            outputs = saved_transform_io.apply_saved_transform(export_path, inputs)
            session.run(tf.tables_initializer())
            result = session.run(outputs['output'])
            self.assertEqual(1, result)
def _createTestInferModel(
        self, m_creator, hparams, sess, init_global_vars=False):
    """Build an INFER-mode model and prepare ``sess`` to run it.

    Global variables are initialized only when ``init_global_vars`` is true.
    Lookup tables and the test iterator are always initialized.

    Returns:
      The model produced by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.INFER
    iterator, src_table, tgt_table, reverse_tgt_table = (
        common_test_utils.create_test_iterator(hparams, mode))
    model = m_creator(hparams, mode, iterator, src_table, tgt_table,
                      reverse_tgt_table, scope='dynamic_seq2seq')

    pending = []
    if init_global_vars:
        pending.append(tf.global_variables_initializer())
    pending.append(tf.tables_initializer())
    pending.append(iterator.initializer)
    for init_op in pending:
        sess.run(init_op)
    return model
def __init__(self, model_path, embedding_size, language, nlp):
    """Restore a trained graph and its weights for inference.

    Args:
      model_path: String prefix. "model.ckpt.meta" and "model.ckpt" are
        appended to it directly, so a directory path must end with a path
        separator.
      embedding_size: Forwarded to ``spacy_wrapper``.
      language: Forwarded to ``spacy_wrapper``.
      nlp: Forwarded to ``spacy_wrapper``.
    """
    # Step 1: restore the meta graph
    with tf.Graph().as_default() as graph:
        saver = tf.train.import_meta_graph(model_path + "model.ckpt.meta")
        self.graph = graph
        # get tensors for inputs and outputs by name
        self.decoder_prediction = graph.get_tensor_by_name('decoder_prediction:0')
        self.intent = graph.get_tensor_by_name('intent:0')
        self.words_inputs = graph.get_tensor_by_name('words_inputs:0')
        self.encoder_inputs_actual_length = graph.get_tensor_by_name(
            'encoder_inputs_actual_length:0')

        # Redefine the py_func that is not serializable. The resulting
        # tensor is never used; the op is created only for its side effect.
        # NOTE(review): presumably this lets the restored graph's py_func
        # resolve to a Python callable again. Confirm before removing.
        def static_wrapper(words):
            return spacy_wrapper(embedding_size, language, nlp, words)
        tf.py_func(static_wrapper, [self.words_inputs], tf.float32, stateful=False)

        # Step 2: restore weights. Bind the session to the restored graph
        # explicitly; a bare tf.Session() uses whichever graph is the default
        # when it is constructed.
        self.sess = tf.Session(graph=graph)
        self.sess.run(tf.tables_initializer())
        saver.restore(self.sess, model_path + "model.ckpt")
def _createTestInferModel(
        self, m_creator, hparams, sess, init_global_vars=False):
    """Build an INFER-mode model and prepare ``sess`` to run it.

    Global variables are initialized only when ``init_global_vars`` is true.
    Lookup tables and the test iterator are always initialized.

    Returns:
      The model produced by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.INFER
    iterator, src_table, tgt_table, reverse_tgt_table = (
        common_test_utils.create_test_iterator(hparams, mode))
    model = m_creator(hparams, mode, iterator, src_table, tgt_table,
                      reverse_tgt_table, scope='dynamic_seq2seq')

    pending = []
    if init_global_vars:
        pending.append(tf.global_variables_initializer())
    pending.append(tf.tables_initializer())
    pending.append(iterator.initializer)
    for init_op in pending:
        sess.run(init_op)
    return model
def create_or_load_model(model, model_dir, session, out_dir, name):
    """Create translation model and initialize or load parameters in session.

    If ``model_dir`` holds a checkpoint, the latest one is restored with
    ``model.saver``. Otherwise all global variables are initialized. Lookup
    tables are initialized on both paths. The logged time covers all of this
    initialization work.

    Args:
      model: Model exposing ``saver`` and ``global_step``.
      model_dir: Directory searched with ``tf.train.latest_checkpoint``.
      session: Session to restore into or initialize.
      out_dir: Unused; kept for interface compatibility.
      name: Label used in the log message.

    Returns:
      A ``(model, global_step)`` tuple, where ``global_step`` is evaluated in
      ``session``.
    """
    start_time = time.time()
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model.saver.restore(session, latest_ckpt)
    else:
        session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())

    # Log only after initialization so the reported time actually covers it.
    elapsed = time.time() - start_time
    if latest_ckpt:
        utils.print_out(
            " loaded %s model parameters from %s, time %.2fs" %
            (name, latest_ckpt, elapsed))
    else:
        utils.print_out(" created %s model with fresh parameters, time %.2fs." %
                        (name, elapsed))

    global_step = model.global_step.eval(session=session)
    return model, global_step
def _run_graph(self, analysis_path, features, schema, stats, predict_data):
    """Runs the preprocessing graph.

    Builds the CSV serving tensors in a fresh graph, initializes lookup
    tables, feeds ``predict_data`` into the ``'csv_example'`` input and
    returns the evaluated outputs. The session is closed on exit.

    Args:
      analysis_path: path to folder containing analysis output. Should contain
        the stats file.
      features: features dict
      schema: schema list
      stats: stats dict. NOTE(review): currently ignored, because it is
        replaced with empty column stats below. Confirm this is intended.
      predict_data: list of csv strings

    Returns:
      The result of running the transform outputs on ``predict_data``.
    """
    # NOTE(review): overrides the caller's ``stats``; kept as-is because
    # callers may rely on empty column stats.
    stats = {'column_stats': {}}
    with tf.Graph().as_default():
        # Entering the Session makes it the default and closes it on exit;
        # Session.as_default() alone never closes it.
        with tf.Session() as session:
            outputs, _, inputs = feature_transforms.build_csv_serving_tensors_for_transform_step(
                analysis_path, features, schema, stats, keep_target=False)
            feed_inputs = {inputs['csv_example']: predict_data}
            session.run(tf.tables_initializer())
            result = session.run(outputs, feed_dict=feed_inputs)
    return result
def start_bundle(self, element=None):
    """Build the transformation graph and the session that runs it.

    Stores the session, the transformed-feature tensors and the
    ``'csv_example'`` input placeholder on ``self``.
    """
    import tensorflow as tf
    from trainer import feature_transforms

    graph = tf.Graph()
    session = tf.Session(graph=graph)

    # Build the transformation graph. The table initializer must be created
    # inside this graph so that ``session`` can run it.
    with graph.as_default():
        features_out, _, input_tensors = (
            feature_transforms.build_csv_serving_tensors_for_transform_step(
                analysis_path=self._analysis_output_dir,
                features=self._features,
                schema=self._schema,
                stats=self._stats,
                keep_target=True))
        session.run(tf.tables_initializer())

    self._session = session
    self._transformed_features = features_out
    self._input_placeholder_tensor = input_tensors['csv_example']
def start_bundle(self, element=None):
    """Build the transformation graph and the session that runs it.

    Stores the session, the transformed-feature tensors and the
    ``'csv_example'`` input placeholder on ``self``.
    """
    import tensorflow as tf
    from trainer import feature_transforms

    graph = tf.Graph()
    session = tf.Session(graph=graph)

    # Build the transformation graph. The table initializer must be created
    # inside this graph so that ``session`` can run it.
    with graph.as_default():
        features_out, _, input_tensors = (
            feature_transforms.build_csv_serving_tensors_for_transform_step(
                analysis_path=self._analysis_output_dir,
                features=self._features,
                schema=self._schema,
                stats=self._stats,
                keep_target=True))
        session.run(tf.tables_initializer())

    self._session = session
    self._transformed_features = features_out
    self._input_placeholder_tensor = input_tensors['csv_example']
def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.

    Writes one random source/target pair to temporary parallel files, builds
    the model and a ``ParallelTextInputPipeline`` for ``mode``, and runs the
    model's non-``None`` outputs once. The temporary files are closed even if
    building or running the graph raises.

    Args:
      mode: Mode passed to ``create_model`` and to the input pipeline.
      params: Optional model params forwarded to ``create_model``.

    Returns:
      A ``(model, fetched_values)`` tuple.
    """
    # Create source and target example
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    try:
        # Build model graph
        model = self.create_model(mode, params)
        input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
            params={
                "source_files": [sources_file.name],
                "target_files": [targets_file.name]
            },
            mode=mode)
        input_fn = training_utils.create_input_fn(
            pipeline=input_pipeline_, batch_size=self.batch_size)
        features, labels = input_fn()
        fetches = [fetch for fetch in model(features, labels, None)
                   if fetch is not None]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                fetches_ = sess.run(fetches)
    finally:
        # Close the temp files on failure too, so a failing test does not leak them.
        sources_file.close()
        targets_file.close()

    return model, fetches_
def test_without_counts(self):
    """Vocabulary lookup tables built from a file without word counts.

    Known words map to ids 0-2 and back. Unknown words map to id 3, which
    maps back to "UNK".
    """
    known_words = ["Hello", ".", "?"]
    vocab_file = test_utils.create_temporary_vocab_file(known_words)
    word_to_id, id_to_word, _, table_size = \
        vocab.create_vocabulary_lookup_table(vocab_file.name)
    self.assertEqual(table_size, 6)

    with self.test_session() as sess:
        for init_op in (tf.global_variables_initializer(),
                        tf.local_variables_initializer(),
                        tf.tables_initializer()):
            sess.run(init_op)

        looked_up_ids = sess.run(word_to_id.lookup(
            tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"])))
        np.testing.assert_array_equal(looked_up_ids, [0, 1, 2, 3, 3])

        looked_up_words = sess.run(id_to_word.lookup(
            tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)))
        decoded = np.char.decode(looked_up_words.astype("S"), "utf-8")
        np.testing.assert_array_equal(decoded, ["Hello", ".", "?", "UNK"])
def test_with_counts(self):
    """Vocabulary lookup tables built from a file that includes word counts.

    Known words map to ids 0-2 and back, and unknown words map to id 3 /
    "UNK". The count table returns the file's counts, or -1 for unknown words.
    """
    known_words = ["Hello", ".", "?"]
    known_counts = [100, 200, 300]
    vocab_file = test_utils.create_temporary_vocab_file(known_words,
                                                        known_counts)
    (word_to_id, id_to_word, word_to_count,
     table_size) = vocab.create_vocabulary_lookup_table(vocab_file.name)
    self.assertEqual(table_size, 6)

    queries = ["Hello", ".", "?", "??", "xxx"]
    with self.test_session() as sess:
        for init_op in (tf.global_variables_initializer(),
                        tf.local_variables_initializer(),
                        tf.tables_initializer()):
            sess.run(init_op)

        looked_up_ids = sess.run(word_to_id.lookup(tf.convert_to_tensor(queries)))
        np.testing.assert_array_equal(looked_up_ids, [0, 1, 2, 3, 3])

        looked_up_words = sess.run(id_to_word.lookup(
            tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)))
        decoded = np.char.decode(looked_up_words.astype("S"), "utf-8")
        np.testing.assert_array_equal(decoded, ["Hello", ".", "?", "UNK"])

        looked_up_counts = sess.run(
            word_to_count.lookup(tf.convert_to_tensor(queries)))
        np.testing.assert_array_equal(looked_up_counts, [100, 200, 300, -1, -1])
def test_sampling(self):
    """TrainSampleHook writes a sample file at steps 0 and 10 but not at 9."""
    hook = hooks.TrainSampleHook(
        params={"every_n_steps": 10}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()

    # (global step, sample file name, expected header; None = no file expected)
    cases = (
        (0, "samples_000000.txt", "Prediction followed by Target @ Step 0"),
        (9, "samples_000009.txt", None),
        (10, "samples_000010.txt", "Prediction followed by Target @ Step 10"),
    )

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        #pylint: disable=W0212
        hooked = monitored_session._HookedSession(sess, [hook])

        for step, filename, header in cases:
            sess.run(tf.assign(global_step, step))
            hooked.run(no_op)
            sample_path = os.path.join(self.sample_dir, filename)
            if header is None:
                self.assertFalse(os.path.exists(sample_path))
            else:
                with open(sample_path, "rb") as sample_file:
                    self.assertIn(header, sample_file.read().decode("utf-8"))
def _run_eval(self):
    """Run model evaluation and generate summaries.

    Restores the latest checkpoint into a session on ``self._graph``. It then
    evaluates ``self._metric_dict`` until the coordinator stops;
    ``OutOfRangeError`` and ``CancelledError`` count as clean stops. Finally
    it logs a histogram and a random-variable summary for each collected
    metric at the checkpoint's global step.
    """
    coord = tf.train.Coordinator(clean_stop_exception_types=(
        tf.errors.CancelledError, tf.errors.OutOfRangeError))
    # Make self._graph the default so the initializer ops are created in the
    # same graph the session runs.
    with self._graph.as_default(), tf.Session(graph=self._graph) as session:
        # Restores previously saved variables from latest checkpoint
        self._saver.restore(session, self._latest_checkpoint)
        session.run([tf.tables_initializer(), tf.local_variables_initializer()])
        threads = tf.train.start_queue_runners(coord=coord, sess=session)
        train_step = session.run(self._gs)

        tf.logging.info('Starting evaluation')
        d = {key: [] for key in ['loss', 'accuracy', 'dice_coefficient', 'hausdorff_distance',
                                 'average_symmetric_surface_distance']}
        with coord.stop_on_exception():
            while not coord.should_stop():
                metric_dict = session.run(self._metric_dict)
                prediction = metric_dict.pop('prediction')
                ground_truth = metric_dict.pop('ground_truth')
                d['loss'].append(metric_dict.pop('loss'))
                d['accuracy'].append(metric_dict.pop('accuracy'))
                d['dice_coefficient'].append(metric_dict.pop('dice_coefficient'))
                d['hausdorff_distance'].append(hd(prediction, ground_truth))
                d['average_symmetric_surface_distance'].append(assd(prediction, ground_truth))

        # Stop and join the queue-runner threads before the session closes.
        # join() re-raises any non-clean error reported to the coordinator.
        coord.request_stop()
        coord.join(threads)

    # Save histogram, mean and std for each variable.
    # dict.items() replaces the Python 2-only dict.iteritems().
    for key, value in d.items():
        self.logger.log_histogram(tag=key, values=value, step=train_step, bins=15)
        self.logger.log_random_variable(tag='eval_'+key, var=value, step=train_step)
    tf.logging.info('Finished evaluation')
def _run_eval(self):
    """Run model evaluation and generate summaries.

    Restores the latest checkpoint into a session on ``self._graph`` and runs
    the summary, final-value and eval ops. It stops when the coordinator
    stops or after ``self._eval_steps`` steps (``None`` means no limit). It
    then writes the last summaries at the checkpoint's global step and logs
    the last final values. If no step ran, it logs a warning instead of
    failing with a NameError.
    """
    coord = tf.train.Coordinator(clean_stop_exception_types=(
        tf.errors.CancelledError, tf.errors.OutOfRangeError))
    # These stay None when no evaluation step runs, e.g. eval_steps == 0 or
    # the input is exhausted on the first step.
    summaries = None
    final_values = None
    # Make self._graph the default so the initializer ops are created in the
    # same graph the session runs.
    with self._graph.as_default(), tf.Session(graph=self._graph) as session:
        # Restores previously saved variables from latest checkpoint
        self._saver.restore(session, self._latest_checkpoint)
        session.run([
            tf.tables_initializer(),
            tf.local_variables_initializer()
        ])
        threads = tf.train.start_queue_runners(coord=coord, sess=session)
        train_step = session.run(self._gs)

        tf.logging.info('Starting Evaluation For Step: {}'.format(train_step))
        with coord.stop_on_exception():
            eval_step = 0
            while not coord.should_stop() and (self._eval_steps is None or
                                               eval_step < self._eval_steps):
                summaries, final_values, _ = session.run(
                    [self._summary_op, self._final_ops_dict, self._eval_ops])
                if eval_step % 100 == 0:
                    tf.logging.info("On Evaluation Step: {}".format(eval_step))
                eval_step += 1

        # Hitting the eval_steps limit leaves the queue runners running, so
        # stop them and wait for their threads before the session closes.
        coord.request_stop()
        coord.join(threads)

    if summaries is None:
        tf.logging.warning(
            'No evaluation steps were run for step {}'.format(train_step))
        return

    # Write the summaries
    self._file_writer.add_summary(summaries, global_step=train_step)
    self._file_writer.flush()
    tf.logging.info(final_values)
def main_op():
    """Return a single op that initializes local variables and lookup tables.

    Returns:
      A grouped op running the local-variable initializer and the table
      initializer.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer())
def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.

    Writes one random source/target pair to temporary parallel files, builds
    the model and a ``ParallelTextInputPipeline`` for ``mode``, and runs the
    model's non-``None`` outputs once. The temporary files are closed even if
    building or running the graph raises.

    Args:
      mode: Mode passed to ``create_model`` and to the input pipeline.
      params: Optional model params forwarded to ``create_model``.

    Returns:
      A ``(model, fetched_values)`` tuple.
    """
    # Create source and target example
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    try:
        # Build model graph
        model = self.create_model(mode, params)
        input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
            params={
                "source_files": [sources_file.name],
                "target_files": [targets_file.name]
            },
            mode=mode)
        input_fn = training_utils.create_input_fn(
            pipeline=input_pipeline_, batch_size=self.batch_size)
        features, labels = input_fn()
        fetches = [fetch for fetch in model(features, labels, None)
                   if fetch is not None]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                fetches_ = sess.run(fetches)
    finally:
        # Close the temp files on failure too, so a failing test does not leak them.
        sources_file.close()
        targets_file.close()

    return model, fetches_
def test_without_counts(self):
    """Vocabulary lookup tables built from a file without word counts.

    Known words map to ids 0-2 and back. Unknown words map to id 3, which
    maps back to "UNK".
    """
    known_words = ["Hello", ".", "?"]
    vocab_file = test_utils.create_temporary_vocab_file(known_words)
    word_to_id, id_to_word, _, table_size = \
        vocab.create_vocabulary_lookup_table(vocab_file.name)
    self.assertEqual(table_size, 6)

    with self.test_session() as sess:
        for init_op in (tf.global_variables_initializer(),
                        tf.local_variables_initializer(),
                        tf.tables_initializer()):
            sess.run(init_op)

        looked_up_ids = sess.run(word_to_id.lookup(
            tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"])))
        np.testing.assert_array_equal(looked_up_ids, [0, 1, 2, 3, 3])

        looked_up_words = sess.run(id_to_word.lookup(
            tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)))
        decoded = np.char.decode(looked_up_words.astype("S"), "utf-8")
        np.testing.assert_array_equal(decoded, ["Hello", ".", "?", "UNK"])
def test_with_counts(self):
    """Vocabulary lookup tables built from a file that includes word counts.

    Known words map to ids 0-2 and back, and unknown words map to id 3 /
    "UNK". The count table returns the file's counts, or -1 for unknown words.
    """
    known_words = ["Hello", ".", "?"]
    known_counts = [100, 200, 300]
    vocab_file = test_utils.create_temporary_vocab_file(known_words,
                                                        known_counts)
    (word_to_id, id_to_word, word_to_count,
     table_size) = vocab.create_vocabulary_lookup_table(vocab_file.name)
    self.assertEqual(table_size, 6)

    queries = ["Hello", ".", "?", "??", "xxx"]
    with self.test_session() as sess:
        for init_op in (tf.global_variables_initializer(),
                        tf.local_variables_initializer(),
                        tf.tables_initializer()):
            sess.run(init_op)

        looked_up_ids = sess.run(word_to_id.lookup(tf.convert_to_tensor(queries)))
        np.testing.assert_array_equal(looked_up_ids, [0, 1, 2, 3, 3])

        looked_up_words = sess.run(id_to_word.lookup(
            tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)))
        decoded = np.char.decode(looked_up_words.astype("S"), "utf-8")
        np.testing.assert_array_equal(decoded, ["Hello", ".", "?", "UNK"])

        looked_up_counts = sess.run(
            word_to_count.lookup(tf.convert_to_tensor(queries)))
        np.testing.assert_array_equal(looked_up_counts, [100, 200, 300, -1, -1])
def test_sampling(self):
    """TrainSampleHook writes a sample file at steps 0 and 10 but not at 9."""
    hook = hooks.TrainSampleHook(
        params={"every_n_steps": 10}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()

    # (global step, sample file name, expected header; None = no file expected)
    cases = (
        (0, "samples_000000.txt", "Prediction followed by Target @ Step 0"),
        (9, "samples_000009.txt", None),
        (10, "samples_000010.txt", "Prediction followed by Target @ Step 10"),
    )

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        #pylint: disable=W0212
        hooked = monitored_session._HookedSession(sess, [hook])

        for step, filename, header in cases:
            sess.run(tf.assign(global_step, step))
            hooked.run(no_op)
            sample_path = os.path.join(self.sample_dir, filename)
            if header is None:
                self.assertFalse(os.path.exists(sample_path))
            else:
                with open(sample_path, "rb") as sample_file:
                    self.assertIn(header, sample_file.read().decode("utf-8"))
def assertSparseOutput(self, expected_indices, expected_values,
                       expected_shape, actual_sparse_tensor, close_values):
    """Evaluate ``actual_sparse_tensor`` and compare it with expected parts.

    Lookup tables are initialized first, in a new default session.

    Args:
      expected_indices: Expected ``indices`` (compared exactly).
      expected_values: Expected ``values``.
      expected_shape: Expected ``dense_shape`` (compared exactly).
      actual_sparse_tensor: Sparse tensor to evaluate.
      close_values: If true, values are compared with ``assertAllClose``;
        otherwise with ``assertAllEqual``.
    """
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        evaluated = actual_sparse_tensor.eval()

    self.assertAllEqual(expected_indices, evaluated.indices)
    self.assertAllEqual(expected_shape, evaluated.dense_shape)
    compare_values = self.assertAllClose if close_values else self.assertAllEqual
    compare_values(expected_values, evaluated.values)
def _createTestTrainModel(self, m_creator, hparams, sess):
    """Build a TRAIN-mode model and initialize variables, tables and iterator.

    Returns:
      The model produced by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.TRAIN
    iterator, src_table, tgt_table = common_test_utils.create_test_iterator(
        hparams, mode)
    model = m_creator(hparams, mode, iterator, src_table, tgt_table,
                      scope='dynamic_seq2seq')
    for init_op in (tf.global_variables_initializer(),
                    tf.tables_initializer(),
                    iterator.initializer):
        sess.run(init_op)
    return model
def _createTestEvalModel(self, m_creator, hparams, sess):
    """Build an EVAL-mode model and initialize its tables and iterator.

    Global variables are not initialized here.

    Returns:
      The model produced by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.EVAL
    iterator, src_table, tgt_table = common_test_utils.create_test_iterator(
        hparams, mode)
    model = m_creator(hparams, mode, iterator, src_table, tgt_table,
                      scope='dynamic_seq2seq')
    for init_op in (tf.tables_initializer(), iterator.initializer):
        sess.run(init_op)
    return model
def load_model(model, ckpt, session, name):
    """Restore ``model`` from checkpoint ``ckpt`` and initialize its tables.

    The elapsed time, including table initialization, is logged.

    Returns:
      The same ``model`` object.
    """
    began = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    elapsed = time.time() - began
    utils.print_out(
        " loaded %s model parameters from %s, time %.2fs" % (name, ckpt, elapsed))
    return model
def test_sampling(self):
    """TrainSampleHook writes a sample file at steps 0 and 10 but not at 9."""
    hook = learner_hooks.TrainSampleHook(
        params={"every_n_steps": 10}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()

    # (global step, sample file name, expected header; None = no file expected)
    cases = (
        (0, "samples_000000.txt", "Prediction followed by Target @ Step 0"),
        (9, "samples_000009.txt", None),
        (10, "samples_000010.txt", "Prediction followed by Target @ Step 10"),
    )

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        hooked = monitored_session._HookedSession(sess, [hook])

        for step, filename, header in cases:
            sess.run(tf.assign(global_step, step))
            hooked.run(no_op)
            sample_path = os.path.join(self.sample_dir, filename)
            if header is None:
                self.assertFalse(os.path.exists(sample_path))
            else:
                with open(sample_path, "rb") as sample_file:
                    self.assertIn(header, sample_file.read().decode("utf-8"))
def test_without_counts(self):
    """Vocabulary lookup tables built from a file without word counts.

    Known words map to ids 0-2 and back. Unknown words map to id 3, which
    maps back to "UNK".
    """
    known_words = ["Hello", ".", "?"]
    vocab_file = create_temporary_vocab_file(known_words)
    word_to_id, id_to_word, _, table_size = \
        vocabulary.create_vocabulary_lookup_table(vocab_file.name)
    self.assertEqual(table_size, 6)

    with self.test_session() as sess:
        for init_op in (tf.global_variables_initializer(),
                        tf.local_variables_initializer(),
                        tf.tables_initializer()):
            sess.run(init_op)

        looked_up_ids = sess.run(word_to_id.lookup(
            tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"])))
        self.assertAllEqual(looked_up_ids, [0, 1, 2, 3, 3])

        looked_up_words = sess.run(id_to_word.lookup(
            tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)))
        decoded = np.char.decode(looked_up_words.astype("S"), "utf-8")
        self.assertAllEqual(decoded, ["Hello", ".", "?", "UNK"])
def test_with_counts(self):
    """Vocabulary lookup tables built from a file that includes word counts.

    Known words map to ids 0-2 and back, and unknown words map to id 3 /
    "UNK". The count table returns the file's counts, or -1 for unknown words.
    """
    known_words = ["Hello", ".", "?"]
    known_counts = [100, 200, 300]
    vocab_file = create_temporary_vocab_file(known_words,
                                             known_counts)
    (word_to_id, id_to_word, word_to_count,
     table_size) = vocabulary.create_vocabulary_lookup_table(vocab_file.name)
    self.assertEqual(table_size, 6)

    queries = ["Hello", ".", "?", "??", "xxx"]
    with self.test_session() as sess:
        for init_op in (tf.global_variables_initializer(),
                        tf.local_variables_initializer(),
                        tf.tables_initializer()):
            sess.run(init_op)

        looked_up_ids = sess.run(word_to_id.lookup(tf.convert_to_tensor(queries)))
        self.assertAllEqual(looked_up_ids, [0, 1, 2, 3, 3])

        looked_up_words = sess.run(id_to_word.lookup(
            tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)))
        decoded = np.char.decode(looked_up_words.astype("S"), "utf-8")
        self.assertAllEqual(decoded, ["Hello", ".", "?", "UNK"])

        looked_up_counts = sess.run(
            word_to_count.lookup(tf.convert_to_tensor(queries)))
        self.assertAllEqual(looked_up_counts, [100, 200, 300, -1, -1])
def _createTestTrainModel(self, m_creator, hparams, sess):
    """Build a TRAIN-mode model and initialize variables, tables and iterator.

    Returns:
      The model produced by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.TRAIN
    iterator, src_table, tgt_table = common_test_utils.create_test_iterator(
        hparams, mode)
    model = m_creator(hparams, mode, iterator, src_table, tgt_table,
                      scope='dynamic_seq2seq')
    for init_op in (tf.global_variables_initializer(),
                    tf.tables_initializer(),
                    iterator.initializer):
        sess.run(init_op)
    return model
def _createTestEvalModel(self, m_creator, hparams, sess):
    """Build an EVAL-mode model and initialize its tables and iterator.

    Global variables are not initialized here.

    Returns:
      The model produced by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.EVAL
    iterator, src_table, tgt_table = common_test_utils.create_test_iterator(
        hparams, mode)
    model = m_creator(hparams, mode, iterator, src_table, tgt_table,
                      scope='dynamic_seq2seq')
    for init_op in (tf.tables_initializer(), iterator.initializer):
        sess.run(init_op)
    return model
def build_init(self):
    """Builds the initialization sub-graph.

    Creates two ops:

    * ``init_op`` initializes every global variable. It is meant to run when
      the graph is first created, before training.
    * ``local_init_op`` initializes every non-trainable global variable,
      every local variable and every lookup table. It is meant to run after a
      previously trained model is loaded, and it is also added to the
      ``'saved_model_main_op'`` collection.

    Returns:
      A tuple containing the init op and local init op to use to initialize the graph.
    """
    init_op = tf.variables_initializer(tf.global_variables(), name='init')

    # For some reason not all local variables are in the local variables collection, but some are in
    # the global variables collection (such as those setup by reader ops).
    # So in addition to initializing local variables in the local_init_op, we also initialize the
    # set of variables in the global variables, that are not trainable.
    # Just to add to the mix, tables are neither, and so must be explicitly included as well.
    # All of these will be initialized after restoring from a checkpoint.
    #
    # Filter by object identity rather than calling list.remove() per trainable
    # variable. remove() raises ValueError for any trainable variable that is
    # not also a global variable, and is quadratic in the variable count.
    trainable_ids = {id(var) for var in tf.trainable_variables()}
    non_trainable_globals = [var for var in tf.global_variables()
                             if id(var) not in trainable_ids]
    local_init_op = tf.group(tf.variables_initializer(non_trainable_globals),
                             tf.variables_initializer(tf.local_variables()),
                             tf.tables_initializer(),
                             name='local_init_op')

    # Add the local initialization op to the main op collection, which is looked up at model loading
    # time, and is automatically invoked after it has been loaded.
    tf.add_to_collection('saved_model_main_op', local_init_op)
    return init_op, local_init_op