def sequential(x, net, defaults={}, name='', reuse=None, var={}, layers={}):
  layers = dict(list(layers.items()) + list(predefined_layers.items()))
  y = x
  logging.info('Building Sequential Network : %s', name)
  with tf.variable_scope(name, reuse=reuse):
    for i in range(len(net)):
      ltype = net[i][0]
      lcfg = net[i][1] if len(net[i]) == 2 else {}
      lname = lcfg.get('name', ltype + str(i))
      ldefs = defaults.get(ltype, {})
      lcfg = dict(list(ldefs.items()) + list(lcfg.items()))
      # Resolve '$'-prefixed config values against the `var` dictionary.
      for k, v in list(lcfg.items()):
        if isinstance(v, str) and v.startswith('$'):
          lcfg[k] = var[v[1:]]
      y = layers[ltype](y, lname, **lcfg)
      logging.info('\t %s \t %s', lname, y.get_shape().as_list())
  return y
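# Hedged usage sketch for `sequential` (illustration only: the 'conv'/'fc'
# layer types, their configs, and the '$stride' placeholder are invented; the
# function merely looks each type up in the merged `layers` registry and
# resolves '$'-prefixed values from `var`):
#
#   net_spec = [
#       ('conv', {'filters': 32, 'kernel': 3}),
#       ('conv', {'filters': 64, 'kernel': 3, 'stride': '$stride'}),
#       ('fc', {'units': 10, 'name': 'logits'}),
#   ]
#   logits = sequential(x, net_spec, defaults={'conv': {'stride': 1}},
#                       name='cnn', var={'stride': 2})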
# `to_proto` is already wrapped into a try...except externally but
# `from_proto` isn't. In order to minimize disruption, catch all the
# exceptions happening during `from_proto` and just log them.
def _module_info_from_proto_safe(module_info_def, import_scope=None):
  """Deserializes the `module_info_def` proto without raising exceptions.

  Args:
    module_info_def: An instance of `module_pb2.SonnetModule`.
    import_scope: Optional `string`. Name scope to use.

  Returns:
    An instance of `ModuleInfo`.
  """
  try:
    return _module_info_from_proto(module_info_def, import_scope)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning(
        "Error encountered when deserializing sonnet ModuleInfo:\n%s", str(e))
    return None
def main(unused_args):
  g = tf.Graph()
  with g.as_default(), tf.device('/cpu:0'):
    # Build the model for evaluation.
    model = create_model(FLAGS, 'eval')
    model.build()

    with tf.Session() as sess:
      # Start the queue runners.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)

      # Run evaluation on the latest checkpoint.
      try:
        for i in range(FLAGS.total_steps):
          inspect_tensors(sess)
      except Exception as e:  # pylint: disable=broad-except
        tf.logging.error("Evaluation failed.")
        coord.request_stop(e)

      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=1)
def get_priming_melodies(self):
  """Runs a batch of training data through MelodyRNN model.

  If the priming mode is 'random_midi', priming the q-network requires a
  random training melody. Therefore this function runs a batch of data from
  the training directory through the internal model, and the resulting
  internal states of the LSTM are stored in a list. The next note in each
  training melody is also stored in a corresponding list called
  'priming_notes'. Therefore, to prime the model with a random melody, it is
  only necessary to select a random index from 0 to batch_size-1 and use the
  hidden states and note at that index as input to the model.
  """
  (next_note_softmax,
   self.priming_states, lengths) = self.q_network.run_training_batch()

  # Get the next note that was predicted for each priming melody to be used
  # in priming.
  self.priming_notes = [0] * len(lengths)
  for i in range(len(lengths)):
    # Each melody has TRAIN_SEQUENCE_LENGTH outputs, but the last note is
    # actually stored at lengths[i]. The rest is padding.
    start_i = i * TRAIN_SEQUENCE_LENGTH
    end_i = start_i + lengths[i] - 1
    end_softmax = next_note_softmax[end_i, :]
    self.priming_notes[i] = np.argmax(end_softmax)

  tf.logging.info('Stored priming notes: %s', self.priming_notes)
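# Toy check of the flattened indexing above, with illustrative numbers
# (TRAIN_SEQUENCE_LENGTH = 4, two melodies of real length 3 and 2). The
# softmax rows are laid out as [m0_t0 .. m0_t3, m1_t0 .. m1_t3].
def _last_real_output_row(i, length, train_sequence_length=4):
  return i * train_sequence_length + length - 1

assert _last_real_output_row(0, 3) == 2  # melody 0 ends at row 2
assert _last_real_output_row(1, 2) == 5  # melody 1 ends at row 5; rest is pad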
def prime_internal_model(self, model):
  """Prime an internal model such as the q_network based on priming mode.

  Args:
    model: The internal model that should be primed.

  Returns:
    The first observation to feed into the model.
  """
  model.state_value = model.get_zero_state()

  if self.priming_mode == 'random_midi':
    priming_idx = np.random.randint(0, len(self.priming_states))
    model.state_value = np.reshape(
        self.priming_states[priming_idx, :],
        (1, model.cell.state_size))
    priming_note = self.priming_notes[priming_idx]
    next_obs = np.array(
        rl_tuner_ops.make_onehot([priming_note], self.num_actions)).flatten()
    tf.logging.debug(
        'Feeding priming state for midi file %s and corresponding note %s',
        priming_idx, priming_note)
  elif self.priming_mode == 'single_midi':
    model.prime_model()
    next_obs = model.priming_note
  elif self.priming_mode == 'random_note':
    next_obs = self.get_random_note()
  else:
    tf.logging.warn('Error! Invalid priming mode. Priming with random note')
    next_obs = self.get_random_note()

  return next_obs
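# The snippet relies on `rl_tuner_ops.make_onehot`; the following is a
# minimal sketch of what such a helper typically does (an assumption, not
# magenta's actual implementation).
import numpy as np

def make_onehot_sketch(int_list, one_hot_length):
  """One row per integer, with a 1.0 at that integer's index."""
  one_hot = np.zeros((len(int_list), one_hot_length))
  one_hot[np.arange(len(int_list)), int_list] = 1.0
  return one_hot

# make_onehot_sketch([5], 38).flatten() -> length-38 vector, 1.0 at index 5.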
def reward_leap_up_back(self, action, resolving_leap_bonus=5.0,
                        leaping_twice_punishment=-5.0):
  """Applies punishment and reward based on the principle leap up leap back.

  Large interval jumps (more than a fifth) should be followed by moving back
  in the opposite direction.

  Args:
    action: One-hot encoding of the chosen action.
    resolving_leap_bonus: Amount of reward dispensed for resolving a previous
      leap.
    leaping_twice_punishment: Amount of reward received for leaping twice in
      the same direction.

  Returns:
    Float reward value.
  """
  leap_outcome = self.detect_leap_up_back(action)
  if leap_outcome == rl_tuner_ops.LEAP_RESOLVED:
    tf.logging.debug('Leap resolved, awarding %s', resolving_leap_bonus)
    return resolving_leap_bonus
  elif leap_outcome == rl_tuner_ops.LEAP_DOUBLED:
    tf.logging.debug('Leap doubled, awarding %s', leaping_twice_punishment)
    return leaping_twice_punishment
  else:
    return 0.0
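# A minimal, hypothetical sketch of the leap-up-leap-back classification; the
# real detect_leap_up_back tracks melody state across steps and uses
# magenta's LEAP_* constants, so everything below is illustrative only.
_NO_LEAP, _LEAP_RESOLVED, _LEAP_DOUBLED = 0, 1, 2
_FIFTH_SEMITONES = 7

def classify_leap(prev_interval, new_interval):
  """Intervals are signed semitone steps between consecutive notes."""
  if abs(prev_interval) <= _FIFTH_SEMITONES:
    return _NO_LEAP
  if prev_interval * new_interval < 0:      # turned back the opposite way
    return _LEAP_RESOLVED
  if abs(new_interval) > _FIFTH_SEMITONES:  # leapt again the same direction
    return _LEAP_DOUBLED
  return _NO_LEAP

assert classify_leap(9, -2) == _LEAP_RESOLVED
assert classify_leap(9, 8) == _LEAP_DOUBLED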
def testFinalCoreHasNoSizeWarning(self):
  cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
  rnn = snt.DeepRNN(cores, skip_connections=False)

  with mock.patch.object(tf.logging, "warning") as mocked_logging_warning:
    # This will produce a warning.
    unused_output_size = rnn.output_size
    self.assertTrue(mocked_logging_warning.called)
    first_call_args = mocked_logging_warning.call_args[0]
    self.assertTrue("final core %s does not have the "
                    ".output_size field" in first_call_args[0])
    self.assertEqual(first_call_args[2], 42)
def testNoSizeButAlreadyConnected(self):
  batch_size = 16
  cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
  rnn = snt.DeepRNN(cores, skip_connections=False)
  unused_output = rnn(tf.zeros((batch_size, 128)),
                      rnn.initial_state(batch_size=batch_size))

  with mock.patch.object(tf.logging, "warning") as mocked_logging_warning:
    output_size = rnn.output_size
    # Correct size is automatically inferred.
    self.assertEqual(output_size, tf.TensorShape([42]))
    self.assertTrue(mocked_logging_warning.called)
    first_call_args = mocked_logging_warning.call_args[0]
    self.assertTrue("DeepRNN has been connected into the graph, "
                    "so inferred output size" in first_call_args[0])
def testWarning(self):
  seq = snt.Sequential([snt.Linear(output_size=23),
                        snt.Linear(output_size=42)])
  seq(tf.placeholder(dtype=tf.float32, shape=[2, 3]))

  with mock.patch.object(tf.logging, "warning") as mocked_logging_warning:
    self.assertEqual((), seq.get_variables())
    self.assertTrue(mocked_logging_warning.called)
    first_call_args = mocked_logging_warning.call_args[0]
    self.assertTrue("will always return an empty tuple" in first_call_args[0])
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
  """Tests that events are discarded after a restart is detected.

  If a step value is observed to be lower than what was previously seen,
  this should force a discard of all previous items with the same tag
  that are outdated.

  Only file versions < 2 use this out-of-order discard logic. Later versions
  discard events based on the step value of SessionLog.START.
  """
  warnings = []
  self.stubs.Set(tf.logging, 'warn', warnings.append)
  gen = _EventGenerator(self)
  acc = ea.EventAccumulator(gen)

  gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
  gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=200, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=300, value=20)
  acc.Reload()
  ## Check that the number of items is what it should be.
  self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200, 300])

  gen.AddScalarTensor('s1', wall_time=1, step=101, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=201, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=301, value=20)
  acc.Reload()
  ## Check that we have discarded 200 and 300 from s1.
  self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301])
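# A minimal sketch (not TensorBoard's actual implementation) of the pre-v2
# out-of-order discard rule these tests exercise: when a new event arrives
# with a step lower than the last one seen for a tag, that tag's existing
# events at or beyond the new step are dropped.
def purge_out_of_order(steps, new_step):
  return [s for s in steps if s < new_step] + [new_step]

assert purge_out_of_order([100, 200, 300], 101) == [100, 101]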
def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
  """Tests that event discards after a restart only affect the misordered tag.

  If a step value is observed to be lower than what was previously seen,
  this should force a discard of all previous items that are outdated, but
  only for the out-of-order tag. Other tags should remain unaffected.

  Only file versions < 2 use this out-of-order discard logic. Later versions
  discard events based on the step value of SessionLog.START.
  """
  warnings = []
  self.stubs.Set(tf.logging, 'warn', warnings.append)
  gen = _EventGenerator(self)
  acc = ea.EventAccumulator(gen)

  gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
  gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
  gen.AddScalarTensor('s2', wall_time=1, step=101, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=200, value=20)
  gen.AddScalarTensor('s2', wall_time=1, step=201, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=300, value=20)
  gen.AddScalarTensor('s2', wall_time=1, step=301, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=101, value=20)
  gen.AddScalarTensor('s3', wall_time=1, step=101, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=201, value=20)
  gen.AddScalarTensor('s1', wall_time=1, step=301, value=20)
  acc.Reload()

  ## Check that we have discarded 200 and 300 for s1.
  self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301])

  ## Check that s1 discards do not affect s2 (written before the out-of-order
  ## event) or s3 (written after it), i.e. that only events from the
  ## out-of-order tag are discarded.
  self.assertEqual([x.step for x in acc.Tensors('s2')], [101, 201, 301])
  self.assertEqual([x.step for x in acc.Tensors('s3')], [101])
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
  """Tests that events are discarded after a restart is detected.

  If a step value is observed to be lower than what was previously seen,
  this should force a discard of all previous items with the same tag
  that are outdated.

  Only file versions < 2 use this out-of-order discard logic. Later versions
  discard events based on the step value of SessionLog.START.
  """
  warnings = []
  self.stubs.Set(tf.logging, 'warn', warnings.append)
  gen = _EventGenerator(self)
  acc = ea.EventAccumulator(gen)

  gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
  gen.AddScalar('s1', wall_time=1, step=100, value=20)
  gen.AddScalar('s1', wall_time=1, step=200, value=20)
  gen.AddScalar('s1', wall_time=1, step=300, value=20)
  acc.Reload()
  ## Check that the number of items is what it should be.
  self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])

  gen.AddScalar('s1', wall_time=1, step=101, value=20)
  gen.AddScalar('s1', wall_time=1, step=201, value=20)
  gen.AddScalar('s1', wall_time=1, step=301, value=20)
  acc.Reload()
  ## Check that we have discarded 200 and 300 from s1.
  self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
  """Tests that event discards after a restart only affect the misordered tag.

  If a step value is observed to be lower than what was previously seen,
  this should force a discard of all previous items that are outdated, but
  only for the out-of-order tag. Other tags should remain unaffected.

  Only file versions < 2 use this out-of-order discard logic. Later versions
  discard events based on the step value of SessionLog.START.
  """
  warnings = []
  self.stubs.Set(tf.logging, 'warn', warnings.append)
  gen = _EventGenerator(self)
  acc = ea.EventAccumulator(gen)

  gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
  gen.AddScalar('s1', wall_time=1, step=100, value=20)
  gen.AddScalar('s1', wall_time=1, step=200, value=20)
  gen.AddScalar('s1', wall_time=1, step=300, value=20)
  gen.AddScalar('s1', wall_time=1, step=101, value=20)
  gen.AddScalar('s1', wall_time=1, step=201, value=20)
  gen.AddScalar('s1', wall_time=1, step=301, value=20)
  gen.AddScalar('s2', wall_time=1, step=101, value=20)
  gen.AddScalar('s2', wall_time=1, step=201, value=20)
  gen.AddScalar('s2', wall_time=1, step=301, value=20)
  acc.Reload()

  ## Check that we have discarded 200 and 300 from s1.
  self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])

  ## Check that s1 discards do not affect s2, i.e. that only events from the
  ## out-of-order tag are discarded.
  self.assertEqual([x.step for x in acc.Scalars('s2')], [101, 201, 301])
def reward_music_theory(self, action):
  """Computes cumulative reward for all music theory functions.

  Args:
    action: A one-hot encoding of the chosen action.

  Returns:
    Float reward value.
  """
  reward = self.reward_key(action)
  tf.logging.debug('Key: %s', reward)
  prev_reward = reward

  reward += self.reward_tonic(action)
  if reward != prev_reward:
    tf.logging.debug('Tonic: %s', reward)
  prev_reward = reward

  reward += self.reward_penalize_repeating(action)
  if reward != prev_reward:
    tf.logging.debug('Penalize repeating: %s', reward)
  prev_reward = reward

  reward += self.reward_penalize_autocorrelation(action)
  if reward != prev_reward:
    tf.logging.debug('Penalize autocorr: %s', reward)
  prev_reward = reward

  reward += self.reward_motif(action)
  if reward != prev_reward:
    tf.logging.debug('Reward motif: %s', reward)
  prev_reward = reward

  reward += self.reward_repeated_motif(action)
  if reward != prev_reward:
    tf.logging.debug('Reward repeated motif: %s', reward)
  prev_reward = reward

  # New rewards based on Gauldin's book, "A Practical Approach to Eighteenth
  # Century Counterpoint".
  reward += self.reward_preferred_intervals(action)
  if reward != prev_reward:
    tf.logging.debug('Reward preferred_intervals: %s', reward)
  prev_reward = reward

  reward += self.reward_leap_up_back(action)
  if reward != prev_reward:
    tf.logging.debug('Reward leap up back: %s', reward)
  prev_reward = reward

  reward += self.reward_high_low_unique(action)
  if reward != prev_reward:
    tf.logging.debug('Reward high low unique: %s', reward)

  return reward
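# The same accumulate-and-log pattern written table-driven; a sketch only
# (the method list mirrors the calls above, but the logging differs
# slightly: here every nonzero component is logged, whereas above 'Key' is
# always logged).
def reward_music_theory_alt(self, action):
  reward_fns = [('Key', self.reward_key),
                ('Tonic', self.reward_tonic),
                ('Penalize repeating', self.reward_penalize_repeating),
                ('Penalize autocorr', self.reward_penalize_autocorrelation),
                ('Reward motif', self.reward_motif),
                ('Reward repeated motif', self.reward_repeated_motif),
                ('Reward preferred_intervals',
                 self.reward_preferred_intervals),
                ('Reward leap up back', self.reward_leap_up_back),
                ('Reward high low unique', self.reward_high_low_unique)]
  reward = 0.0
  for label, fn in reward_fns:
    delta = fn(action)
    if delta:  # log only the components that actually contributed
      tf.logging.debug('%s: %s', label, reward + delta)
    reward += delta
  return reward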
def restore_from_directory(self, directory=None, checkpoint_name=None,
                           reward_file_name=None):
  """Restores this model from a saved checkpoint.

  Args:
    directory: Path to directory where checkpoint is located. If
      None, defaults to self.output_dir.
    checkpoint_name: The name of the checkpoint within the
      directory.
    reward_file_name: The name of the .npz file where the stored
      rewards are saved. If None, will not attempt to load stored
      rewards.
  """
  if directory is None:
    directory = self.output_dir

  if checkpoint_name is not None:
    checkpoint_file = os.path.join(directory, checkpoint_name)
  else:
    tf.logging.info('Directory %s.', directory)
    checkpoint_file = tf.train.latest_checkpoint(directory)

  if checkpoint_file is None:
    tf.logging.fatal('Error! Cannot locate checkpoint in the directory')
    return

  # TODO(natashamjaques): Remove print statement once tf.logging outputs
  # to Jupyter notebooks (once the following issue is resolved:
  # https://github.com/tensorflow/tensorflow/issues/3047)
  print('Attempting to restore from checkpoint', checkpoint_file)
  tf.logging.info('Attempting to restore from checkpoint %s', checkpoint_file)
  self.saver.restore(self.session, checkpoint_file)

  if reward_file_name is not None:
    npz_file_name = os.path.join(directory, reward_file_name)
    # TODO(natashamjaques): Remove print statement once tf.logging outputs
    # to Jupyter notebooks (once the following issue is resolved:
    # https://github.com/tensorflow/tensorflow/issues/3047)
    print('Attempting to load saved reward values from file', npz_file_name)
    tf.logging.info('Attempting to load saved reward values from file %s',
                    npz_file_name)
    npz_file = np.load(npz_file_name)

    self.rewards_batched = npz_file['train_rewards']
    self.music_theory_rewards_batched = npz_file['train_music_theory_rewards']
    self.note_rnn_rewards_batched = npz_file['train_note_rnn_rewards']
    self.eval_avg_reward = npz_file['eval_rewards']
    self.eval_avg_music_theory_reward = npz_file['eval_music_theory_rewards']
    self.eval_avg_note_rnn_reward = npz_file['eval_note_rnn_rewards']
    self.target_val_list = npz_file['target_val_list']
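# Sketch of the matching save step (not shown in the source; only the .npz
# key names are taken from the load code above, while the function and file
# name are assumptions). `rl_tuner` stands for the model whose attributes
# restore_from_directory reads back.
def save_reward_values(rl_tuner, directory, reward_file_name='rewards.npz'):
  np.savez(os.path.join(directory, reward_file_name),
           train_rewards=rl_tuner.rewards_batched,
           train_music_theory_rewards=rl_tuner.music_theory_rewards_batched,
           train_note_rnn_rewards=rl_tuner.note_rnn_rewards_batched,
           eval_rewards=rl_tuner.eval_avg_reward,
           eval_music_theory_rewards=rl_tuner.eval_avg_music_theory_reward,
           eval_note_rnn_rewards=rl_tuner.eval_avg_note_rnn_reward,
           target_val_list=rl_tuner.target_val_list)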
def run_epoch(session, config, graph, iterator, ops=None,
              summary_writer=None, summary_prefix=None, saver=None):
  """Runs the model on the given data."""
  if not ops:
    ops = []

  def should_monitor(step):
    return step and c['monitoring_frequency'] and (step + 1) % c['monitoring_frequency'] == 0

  def should_save(step):
    return step and c['saving_frequency'] and (step + 1) % c['saving_frequency'] == 0

  # Shortcuts; ugly, but they increase readability.
  c = config
  g = graph
  m = Monitor(summary_writer, summary_prefix)

  # Stagger worker start-up.
  while g['step_number'].eval() < FLAGS.task * c['next_worker_delay']:
    pass

  for step, (inputs, lengths) in enumerate(iterator):
    # Define what we feed.
    feed_dict = {g['inputs']: inputs,
                 g['input_lengths']: lengths}

    # Define what we fetch.
    fetch = dict(g['observed'])
    fetch['total_neg_loglikelihood'] = g['total_neg_loglikelihood']
    fetch['total_correct'] = g['total_correct']
    fetch['_ops'] = ops

    # Run!
    r = session.run(fetch, feed_dict)

    # Update the monitor accumulators. We do not predict the first words,
    # which is why batch_size has to be subtracted from the word total.
    m.total_neg_loglikelihood += r['total_neg_loglikelihood']
    m.total_correct += r['total_correct']
    m.steps += 1
    m.words += sum(lengths) - c['batch_size']
    m.sentences += c['batch_size']
    m.words_including_padding += c['batch_size'] * len(inputs[0])
    m.step_number = g['step_number'].eval()
    m.learning_rate = float(g['learning_rate'].eval())
    for key in g['observed']:
      m.observed[key] += r[key]

    if should_monitor(step):
      tf.logging.info('monitor')
      result = m.monitor()
    if saver and should_save(step):
      print("saved")
      saver.save(session, os.path.join(FLAGS.train_path, 'model'))

  if not should_monitor(step):
    result = m.monitor()
  if saver:
    saver.save(session, os.path.join(FLAGS.train_path, 'model'))
  return result
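# Quick standalone check of the frequency logic above: with a frequency of
# 100, monitoring fires at steps 99, 199, 299, ... and never at step 0
# because of the leading `step and ...` guard.
def _should_fire(step, freq):
  return bool(step and freq and (step + 1) % freq == 0)

assert [s for s in range(300) if _should_fire(s, 100)] == [99, 199, 299]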
def main(_):
  # Configuration.
  num_unrolls = FLAGS.num_steps
  if FLAGS.seed:
    tf.set_random_seed(FLAGS.seed)

  # Problem.
  problem, net_config, net_assignments = util.get_config(FLAGS.problem,
                                                         FLAGS.path)

  # Optimizer setup.
  if FLAGS.optimizer == "Adam":
    cost_op = problem()
    problem_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    problem_reset = tf.variables_initializer(problem_vars)

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    update = optimizer.minimize(cost_op)
    # `get_slot_names` returns strings, not variables, so look up the slot
    # variables (created by `minimize` above) before building the reset op.
    optimizer_reset = tf.variables_initializer(
        [optimizer.get_slot(var, name)
         for var in problem_vars
         for name in optimizer.get_slot_names()])
    reset = [problem_reset, optimizer_reset]
  elif FLAGS.optimizer == "L2L":
    if FLAGS.path is None:
      logging.warning("Evaluating untrained L2L optimizer")
    optimizer = meta.MetaOptimizer(**net_config)
    meta_loss = optimizer.meta_loss(problem, 1, net_assignments=net_assignments)
    _, update, reset, cost_op, _ = meta_loss
  else:
    raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer))

  with ms.MonitoredSession() as sess:
    # Prevent accidental changes to the graph.
    tf.get_default_graph().finalize()

    total_time = 0
    total_cost = 0
    for _ in xrange(FLAGS.num_epochs):
      # Training.
      time, cost = util.run_epoch(sess, cost_op, [update], reset,
                                  num_unrolls)
      total_time += time
      total_cost += cost

    # Results.
    util.print_stats("Epoch {}".format(FLAGS.num_epochs), total_cost,
                     total_time, FLAGS.num_epochs)
def evaluate_model(sess, target_cross_entropy_losses,
                   target_cross_entropy_loss_weights, global_step,
                   summary_writer, summary_op):
  """Computes perplexity-per-word over the evaluation dataset.

  Summaries and perplexity-per-word are written out to the eval directory.

  Args:
    sess: Session object.
    target_cross_entropy_losses: Tensor of per-word cross entropy losses.
    target_cross_entropy_loss_weights: Tensor of per-word loss weights.
    global_step: Integer; global step of the model checkpoint.
    summary_writer: Instance of SummaryWriter.
    summary_op: Op for generating model summaries.
  """
  # Log model summaries on a single batch.
  summary_str = sess.run(summary_op)
  summary_writer.add_summary(summary_str, global_step)

  # Compute perplexity over the entire dataset.
  num_eval_batches = int(
      math.ceil(num_eval_examples / batch_size))

  start_time = time.time()
  sum_losses = 0.
  sum_weights = 0.
  for i in xrange(num_eval_batches):
    cross_entropy_losses, weights = sess.run([
        target_cross_entropy_losses,
        target_cross_entropy_loss_weights
    ])
    sum_losses += np.sum(cross_entropy_losses * weights)
    sum_weights += np.sum(weights)
    if not i % 100:
      tf.logging.info("Computed losses for %d of %d batches.", i + 1,
                      num_eval_batches)
  eval_time = time.time() - start_time

  perplexity = math.exp(sum_losses / sum_weights)
  tf.logging.info("Perplexity = %f (%.2g sec)", perplexity, eval_time)

  # Log perplexity to the SummaryWriter.
  summary = tf.Summary()
  value = summary.value.add()
  value.simple_value = perplexity
  value.tag = "Perplexity"
  summary_writer.add_summary(summary, global_step)

  # Write the Events file to the eval directory.
  summary_writer.flush()
  tf.logging.info("Finished processing evaluation at global step %d.",
                  global_step)
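# Worked example of the final computation above, with illustrative numbers:
# if the summed weighted cross-entropy is 460.5 nats over 100 words, the
# perplexity is exp(460.5 / 100) ~= 100.
assert abs(math.exp(460.5 / 100.0) - 100.0) < 0.1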
def run_once(global_step, target_cross_entropy_losses,
             target_cross_entropy_loss_weights, saver, summary_writer,
             summary_op):
  """Evaluates the latest model checkpoint.

  Args:
    global_step: Variable holding the global step of the model checkpoint.
    target_cross_entropy_losses: Tensor of per-word cross entropy losses.
    target_cross_entropy_loss_weights: Tensor of per-word loss weights.
    saver: Instance of tf.train.Saver for restoring model Variables.
    summary_writer: Instance of SummaryWriter.
    summary_op: Op for generating model summaries.
  """
  # The latest checkpoint.
  model_path = tf.train.latest_checkpoint(checkpoint_dir)
  if not model_path:
    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
                    checkpoint_dir)
    return

  with tf.Session() as sess:
    # Load model from checkpoint.
    tf.logging.info("Loading model from checkpoint: %s", model_path)
    saver.restore(sess, model_path)
    step = tf.train.global_step(sess, global_step.name)
    tf.logging.info("Successfully loaded %s at global step = %d.",
                    os.path.basename(model_path), step)
    if step < min_global_step:
      tf.logging.info("Skipping evaluation. Global step = %d < %d", step,
                      min_global_step)
      return

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # Run evaluation on the latest checkpoint.
    try:
      evaluate_model(
          sess=sess,
          target_cross_entropy_losses=target_cross_entropy_losses,
          target_cross_entropy_loss_weights=target_cross_entropy_loss_weights,
          global_step=step,
          summary_writer=summary_writer,
          summary_op=summary_op)
    except Exception as e:  # pylint: disable=broad-except
      tf.logging.error("Evaluation failed.")
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)
def run():
  """Runs evaluation in a loop, and logs summaries to TensorBoard."""
  # Create the evaluation directory if it doesn't exist.
  if not tf.gfile.IsDirectory(eval_dir):
    tf.logging.info("Creating eval directory: %s", eval_dir)
    tf.gfile.MakeDirs(eval_dir)

  g = tf.Graph()
  with g.as_default():
    images, input_seqs, target_seqs, input_mask = Build_Inputs(mode, input_file_pattern)
    net_image_embeddings = Build_Image_Embeddings(mode, images, train_inception)
    net_seq_embeddings = Build_Seq_Embeddings(input_seqs)
    _, target_cross_entropy_losses, target_cross_entropy_loss_weights, network = \
        Build_Model(mode, net_image_embeddings, net_seq_embeddings, target_seqs, input_mask)

    global_step = tf.Variable(
        initial_value=0,
        dtype=tf.int32,
        name="global_step",
        trainable=False,
        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.VARIABLES])

    # Create the Saver to restore model Variables.
    saver = tf.train.Saver()

    # Create the summary operation and the summary writer.
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter(eval_dir)

    g.finalize()

    # Run a new evaluation every eval_interval_secs.
    while True:
      start = time.time()
      tf.logging.info("Starting evaluation at " + time.strftime(
          "%Y-%m-%d-%H:%M:%S", time.localtime()))
      run_once(global_step, target_cross_entropy_losses,
               target_cross_entropy_loss_weights,
               saver, summary_writer, summary_op)
      time_to_next_eval = start + eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
def _build_input_fn(input_file_pattern, batch_size, mode):
  """Builds an input function.

  Args:
    input_file_pattern: The file pattern for examples.
    batch_size: Batch size.
    mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

  Returns:
    An input function that, when called, returns the (features, target)
    tuple described in `_input_fn` below.
  """
  def _input_fn():
    """Supplies the input to the model.

    Returns:
      A tuple consisting of 1) a dictionary of tensors whose keys are
      the feature names, and 2) a tensor of target labels if the mode
      is not INFER (and None, otherwise).
    """
    logging.info("Reading files from %s", input_file_pattern)
    input_files = sorted(list(tf.gfile.Glob(input_file_pattern)))
    logging.info("Reading files from %s", input_files)
    include_target_column = (mode != tf.contrib.learn.ModeKeys.INFER)
    features_spec = tf.contrib.layers.create_feature_spec_for_parsing(
        feature_columns=_get_feature_columns(include_target_column))

    if FLAGS.use_gzip:
      def gzip_reader():
        return tf.TFRecordReader(
            options=tf.python_io.TFRecordOptions(
                compression_type=TFRecordCompressionType.GZIP))
      reader_fn = gzip_reader
    else:
      reader_fn = tf.TFRecordReader

    features = tf.contrib.learn.io.read_batch_features(
        file_pattern=input_files,
        batch_size=batch_size,
        queue_capacity=3 * batch_size,
        randomize_input=(mode == tf.contrib.learn.ModeKeys.TRAIN),
        feature_queue_capacity=FLAGS.feature_queue_capacity,
        reader=reader_fn,
        features=features_spec)

    target = None
    if include_target_column:
      target = features.pop(FLAGS.target_field)
    return features, target

  return _input_fn
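# Hedged usage sketch: everything except _build_input_fn itself is an
# assumption here (the estimator choice, the feature columns, the file
# pattern), chosen to show how a tf.contrib.learn estimator consumes the
# returned input_fn.
def _train_sketch():
  train_input_fn = _build_input_fn(
      input_file_pattern='/tmp/examples-*.tfrecord',
      batch_size=128,
      mode=tf.contrib.learn.ModeKeys.TRAIN)
  estimator = tf.contrib.learn.LinearClassifier(
      feature_columns=_get_feature_columns(False))
  estimator.fit(input_fn=train_input_fn, steps=1000)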