def testSessionLogStartMessageDiscardsExpiredEvents(self):
    """Test that SessionLog.START message discards expired events.

    This discard logic is preferred over the out-of-order step discard logic,
    but this logic can only be used for event protos which have the SessionLog
    enum, which was introduced to event.proto for file_version >= brain.Event:2.
    """
    generator = _EventGenerator(self)
    accumulator = ea.EventAccumulator(generator)
    # Declare file_version >= 2 so the SessionLog-based discard path is used.
    generator.AddEvent(
        tf.Event(wall_time=0, step=1, file_version='brain.Event:2'))
    # Populate two tags; every event at step >= 201 should later be dropped.
    for tag, step in (('s1', 100), ('s1', 200), ('s1', 300), ('s1', 400),
                      ('s2', 202), ('s2', 203)):
        generator.AddScalarTensor(tag, wall_time=1, step=step, value=20)
    start_log = tf.SessionLog(status=tf.SessionLog.START)
    generator.AddEvent(tf.Event(wall_time=2, step=201, session_log=start_log))
    accumulator.Reload()
    # Only events whose step precedes the restart step (201) survive.
    self.assertEqual([x.step for x in accumulator.Tensors('s1')], [100, 200])
    self.assertEqual([x.step for x in accumulator.Tensors('s2')], [])
# Example usages of the Python SessionLog() class (collected source snippets).
def testSessionLogStartMessageDiscardsExpiredEvents(self):
    """Test that SessionLog.START message discards expired events.

    This discard logic is preferred over the out-of-order step discard logic,
    but this logic can only be used for event protos which have the SessionLog
    enum, which was introduced to event.proto for file_version >= brain.Event:2.
    """
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    # file_version >= 2 enables SessionLog-based discarding.
    gen.AddEvent(tf.Event(wall_time=0, step=1, file_version='brain.Event:2'))
    # (tag, step) pairs written before the restart marker, in order.
    scalar_events = [('s1', 100), ('s1', 200), ('s1', 300), ('s1', 400),
                     ('s2', 202), ('s2', 203)]
    for tag, step in scalar_events:
        gen.AddScalar(tag, wall_time=1, step=step, value=20)
    restart = tf.SessionLog(status=tf.SessionLog.START)
    gen.AddEvent(tf.Event(wall_time=2, step=201, session_log=restart))
    acc.Reload()
    # Everything at step >= 201 must be discarded by the START message.
    self.assertEqual([e.step for e in acc.Scalars('s1')], [100, 200])
    self.assertEqual([e.step for e in acc.Scalars('s2')], [])
def initialize_tf_variables(self):
    """
    Initialize tensorflow variables (either initializes them from scratch or restores from checkpoint).

    :return: updated TeLL session (self)
    """
    sess = self.tf_session
    checkpoint = self.workspace.get_checkpoint()
    with Timer(name="Initializing variables"):
        # Always run the initializers first; a checkpoint restore below
        # will overwrite the freshly initialized values.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        if checkpoint is None:
            # Fresh run: publish the graph to every summary writer.
            for summary in self.tf_summaries.values():
                summary.add_graph(sess.graph)
        else:
            # Restore weights from the checkpoint and resume one step
            # after the stored global step.
            self.tf_saver.restore(sess, checkpoint)
            step = sess.run(self.__global_step_placeholder) + 1
            self.global_step = step
            # Reopen the summary writers and emit a SessionLog.START at the
            # resume step so downstream readers discard stale events.
            for summary in self.tf_summaries.values():
                summary.reopen()
                summary.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START), global_step=step)
            print("Resuming from checkpoint '{}' at iteration {}".format(checkpoint, step))
    return self
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
    """Tests that events are discarded after a restart is detected.

    If a step value is observed to be lower than what was previously seen,
    this should force a discard of all previous items with the same tag
    that are outdated.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    warnings = []
    self.stubs.Set(tf.logging, 'warn', warnings.append)
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    # file_version < 2 selects the out-of-order step discard logic.
    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
    for step in (100, 200, 300):
        gen.AddScalarTensor('s1', wall_time=1, step=step, value=20)
    acc.Reload()
    # All three events are present before the restart is simulated.
    self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200, 300])
    # Step dropping from 300 back to 101 signals a restart.
    for step in (101, 201, 301):
        gen.AddScalarTensor('s1', wall_time=1, step=step, value=20)
    acc.Reload()
    # Steps 200 and 300 were discarded as outdated once step 101 arrived.
    self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301])
def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
    """Tests that event discards after restart, only affect the misordered tag.

    If a step value is observed to be lower than what was previously seen,
    this should force a discard of all previous items that are outdated, but
    only for the out of order tag. Other tags should remain unaffected.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    warnings = []
    self.stubs.Set(tf.logging, 'warn', warnings.append)
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    # file_version < 2 selects the out-of-order step discard logic.
    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
    # (tag, step) pairs in write order; s1 goes out of order at step 101.
    events = [('s1', 100), ('s2', 101), ('s1', 200), ('s2', 201),
              ('s1', 300), ('s2', 301), ('s1', 101), ('s3', 101),
              ('s1', 201), ('s1', 301)]
    for tag, step in events:
        gen.AddScalarTensor(tag, wall_time=1, step=step, value=20)
    acc.Reload()
    # s1 discarded its outdated steps 200 and 300.
    self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301])
    # s2 (written before the out-of-order event) and s3 (written after)
    # are unaffected: only the misordered tag is discarded.
    self.assertEqual([x.step for x in acc.Tensors('s2')], [101, 201, 301])
    self.assertEqual([x.step for x in acc.Tensors('s3')], [101])
def testSessionLogSummaries(self):
    """Checks that SessionLog events are tallied per status by the inspector."""
    # (status, step) pairs describing two sessions: START..STOP twice,
    # with three CHECKPOINTs inside the first one.
    status_steps = [
        (tf.SessionLog.START, 0),
        (tf.SessionLog.CHECKPOINT, 1),
        (tf.SessionLog.CHECKPOINT, 2),
        (tf.SessionLog.CHECKPOINT, 3),
        (tf.SessionLog.STOP, 4),
        (tf.SessionLog.START, 5),
        (tf.SessionLog.STOP, 6),
    ]
    data = [{'session_log': tf.SessionLog(status=status), 'step': step}
            for status, step in status_steps]
    self._WriteScalarSummaries(data)
    units = efi.get_inspection_units(self.logdir)
    self.assertEqual(1, len(units))
    printable = efi.get_dict_to_print(units[0].field_to_obs)
    # START and STOP report their step lists; CHECKPOINT reports a count.
    self.assertEqual(printable['sessionlog:start']['steps'], [0, 5])
    self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6])
    self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3)
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
    """Tests that events are discarded after a restart is detected.

    If a step value is observed to be lower than what was previously seen,
    this should force a discard of all previous items with the same tag
    that are outdated.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    logged_warnings = []
    self.stubs.Set(tf.logging, 'warn', logged_warnings.append)
    generator = _EventGenerator(self)
    accumulator = ea.EventAccumulator(generator)
    # file_version < 2 selects the out-of-order step discard logic.
    generator.AddEvent(
        tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
    generator.AddScalar('s1', wall_time=1, step=100, value=20)
    generator.AddScalar('s1', wall_time=1, step=200, value=20)
    generator.AddScalar('s1', wall_time=1, step=300, value=20)
    accumulator.Reload()
    # All three events are present before the restart is simulated.
    self.assertEqual(
        [event.step for event in accumulator.Scalars('s1')], [100, 200, 300])
    # Step dropping from 300 back to 101 signals a restart.
    generator.AddScalar('s1', wall_time=1, step=101, value=20)
    generator.AddScalar('s1', wall_time=1, step=201, value=20)
    generator.AddScalar('s1', wall_time=1, step=301, value=20)
    accumulator.Reload()
    # Steps 200 and 300 were discarded as outdated once step 101 arrived.
    self.assertEqual(
        [event.step for event in accumulator.Scalars('s1')],
        [100, 101, 201, 301])
def train():
    """Run the CycleGAN training loop.

    Builds the image batch readers and the model, optionally restores the
    latest checkpoint from FLAGS.ckpt_dir_path, then repeatedly calls
    train_one_step until it returns a falsy value.

    Fixes over the original:
    - queue-runner threads are stopped/joined in a ``finally`` so they do
      not leak if ``train_one_step`` raises;
    - the summary FileWriter is closed, flushing any buffered events.
    """
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)
    xx_real = build_image_batch_reader(
        FLAGS.x_images_dir_path, FLAGS.batch_size)
    yy_real = build_image_batch_reader(
        FLAGS.y_images_dir_path, FLAGS.batch_size)
    # Shared mutable state passed to train_one_step; presumably a pool of
    # previously generated images for discriminator training — TODO confirm.
    image_pool = {}
    model = build_cycle_gan(xx_real, yy_real, '')
    summaries = build_summaries(model)
    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())
        if ckpt_source_path is not None:
            tf.train.Saver().restore(session, ckpt_source_path)
            # Emit SessionLog.START at the restored step so summary readers
            # give up overlapped old data written after that step.
            step = session.run(model['step'])
            reporter.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START),
                global_step=step)
        # Make the dataset readers work.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while train_one_step(model, summaries, image_pool, reporter):
                pass
        finally:
            # Always stop and join the reader threads, even on error.
            coord.request_stop()
            coord.join(threads)
            # Flush and close the summary writer.
            reporter.close()