def fit_generator(sess, num_iter, operations=None, batch_generator=None, inputs=None,
                  outputs=None, static_inputs=None, print_interval=100):
    """Run ``num_iter`` steps of ``operations`` in ``sess``, logging outputs periodically.

    Each step builds a feed dict from the next item of ``batch_generator``
    (zipped positionally with ``inputs``) plus ``static_inputs`` and runs
    ``operations``. Every ``print_interval`` steps the ``outputs`` tensors are
    fetched in the same ``sess.run`` call and printed together with an ETA.
    After the loop, the total elapsed time is printed.

    Args:
        sess: Object exposing ``run(fetches, feed_dict=...)`` -- presumably a
            ``tf.Session``.
        num_iter: Number of steps to run; steps are numbered from 1.
        operations: A ``tf.Operation`` or a list/tuple of fetches to run on
            every step. Defaults to no operations.
        batch_generator: Iterator whose items are sequences of values, one per
            entry of ``inputs``. If None, only ``static_inputs`` are fed.
        inputs: A ``tf.Tensor`` or a list/tuple of tensors that receive the
            values of each batch. Defaults to no inputs.
        outputs: A dict mapping display names to tensors, or a single
            ``tf.Tensor`` (logged under the name ``'loss'``). Fetched values
            are formatted with ``'{:.4f}'``, so they should be scalars.
        static_inputs: Feed-dict entries added on every step. They are applied
            after the batch entries, so they win on key collisions.
        print_interval: Log every this many steps. 0 or None disables the
            per-step log (the final elapsed-time line is still printed).

    Returns:
        None.
    """
    # None sentinels instead of shared mutable defaults.
    operations = [] if operations is None else operations
    inputs = [] if inputs is None else inputs
    outputs = {} if outputs is None else outputs
    static_inputs = {} if static_inputs is None else static_inputs

    if isinstance(operations, tuple):
        operations = list(operations)
    elif not isinstance(operations, list):
        assert isinstance(operations, tf.Operation), \
            'operations must be a tf.Operation or a list/tuple of them'
        operations = [operations]
    if not isinstance(inputs, (list, tuple)):
        assert isinstance(inputs, tf.Tensor), \
            'inputs must be a tf.Tensor or a list/tuple of them'
        inputs = [inputs]
    if not isinstance(outputs, dict):
        assert isinstance(outputs, tf.Tensor), \
            'outputs must be a tf.Tensor or a dict of them'
        outputs = {'loss': outputs}

    # Build names and tensors from a single items() snapshot so they stay
    # aligned; a real list is needed for ``operations + output_tensors``.
    output_items = list(outputs.items())
    output_names = [name for name, _ in output_items]
    output_tensors = [tensor for _, tensor in output_items]

    tic = ti.default_timer()
    for step in range(1, num_iter + 1):
        feed_dict = {}
        if batch_generator is not None:
            feed_dict.update(zip(inputs, next(batch_generator)))
        # Applied last so static entries override batch entries.
        feed_dict.update(static_inputs)
        # The truthiness guard avoids ZeroDivisionError when print_interval is 0/None.
        if print_interval and step % print_interval == 0:
            # Fetch outputs in the same run as the ops; drop the op results.
            fetched = sess.run(operations + output_tensors, feed_dict=feed_dict)
            values = fetched[len(operations):]
            elapsed = ti.default_timer() - tic
            eta = elapsed / step * (num_iter - step)
            parts = ['[Step: {}/{} ETA: {:.0f}s]'.format(step, num_iter, eta)]
            parts.extend(' {}: {:.4f}'.format(name, value)
                         for name, value in zip(output_names, values))
            print(''.join(parts))
        else:
            sess.run(operations, feed_dict=feed_dict)
    # Use num_iter rather than the loop variable, which is unbound when
    # num_iter == 0; the value printed is total elapsed time, not an ETA.
    elapsed = ti.default_timer() - tic
    print('[Step: {}/{} Elapsed: {:.0f}s]'.format(num_iter, num_iter, elapsed))
评论列表
文章目录