def run_targets(sess,
                dbinterface,
                target_name,
                target,
                valid_loop,
                num_steps,
                online_agg_func,
                agg_func,
                save_intermediate_freq=None,
                validation_only=False):
    """Evaluate ``target`` for ``num_steps`` steps and aggregate the results.

    Each step produces a dict-like result, either from
    ``valid_loop(sess, target)`` or, when ``valid_loop`` is None, from
    ``sess.run(target)``. Step results are folded incrementally with
    ``online_agg_func`` and the final aggregate is produced by ``agg_func``.

    Args:
        sess: Passed to ``valid_loop``, or used via ``sess.run(target)`` when
            ``valid_loop`` is None (presumably a TensorFlow session -- confirm).
        dbinterface: Object exposing ``outrecs`` (sized and sliceable),
            ``save(valid_res=..., step=..., validation_only=...)`` and
            ``sync_with_host()``. Only touched when ``save_intermediate_freq``
            is not None.
        target_name: Key under which step results are saved; also used as the
            progress-bar description.
        target: What is evaluated on every step.
        valid_loop: Optional callable ``(sess, target) -> result`` used instead
            of ``sess.run(target)``.
        num_steps: Number of steps to run.
        online_agg_func: Callable ``(agg_res, res, step) -> agg_res``. On the
            first step it receives ``None`` as ``agg_res``.
        agg_func: Callable ``(agg_res) -> result`` that produces the final
            result from the accumulated value.
        save_intermediate_freq: If not None, the raw step result is saved via
            ``dbinterface.save`` on every step divisible by this value
            (starting at step 0). The records appended to
            ``dbinterface.outrecs`` during this call are then attached as
            ``result['intermediate_steps']``. Must be non-zero.
        validation_only: Forwarded unchanged to ``dbinterface.save``.

    Returns:
        The value returned by ``agg_func``. When ``save_intermediate_freq`` is
        not None, it also has an ``'intermediate_steps'`` entry, so it must
        support item assignment in that case.

    Raises:
        ValueError: If ``save_intermediate_freq`` is 0.
        TypeError: If a step result has no ``keys`` attribute.
    """
    # TODO: this code resembles the train() function; possibly unify them.
    saving = save_intermediate_freq is not None
    if saving:
        if save_intermediate_freq == 0:
            # Previously surfaced as a ZeroDivisionError from ``step % 0``.
            raise ValueError('save_intermediate_freq must be non-zero')
        # Remember where this call's records will start in outrecs.
        n0 = len(dbinterface.outrecs)

    agg_res = None
    for step in tqdm.trange(num_steps, desc=target_name):
        if valid_loop is not None:
            res = valid_loop(sess, target)
        else:
            res = sess.run(target)
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if not hasattr(res, 'keys'):
            raise TypeError('result must be a dictionary')
        if saving and step % save_intermediate_freq == 0:
            dbinterface.save(valid_res={target_name: res},
                             step=step,
                             validation_only=validation_only)
        agg_res = online_agg_func(agg_res, res, step)

    result = agg_func(agg_res)

    if saving:
        # NOTE(review): presumably flushes pending saves so that outrecs
        # includes every record saved above -- confirm against dbinterface.
        dbinterface.sync_with_host()
        n1 = len(dbinterface.outrecs)
        result['intermediate_steps'] = dbinterface.outrecs[n0:n1]
    return result
评论列表
文章目录