loop.py source code

python

Project: SSD-Keras_Tensorflow  Author: jedol
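import timeit as ti

import tensorflow as tf


def fit_generator(sess, num_iter, operations=None, batch_generator=None, inputs=None,
                  outputs=None, static_inputs=None, print_interval=100):
    """Run num_iter training steps, logging the outputs every print_interval steps."""
    # Normalize arguments (None defaults avoid shared mutable default values).
    if operations is None:
        operations = []
    elif isinstance(operations, tuple):
        operations = list(operations)
    elif not isinstance(operations, list):
        assert isinstance(operations, tf.Operation)
        operations = [operations]
    if inputs is None:
        inputs = []
    elif not isinstance(inputs, (list, tuple)):
        assert isinstance(inputs, tf.Tensor)
        inputs = [inputs]
    if outputs is None:
        outputs = {}
    elif not isinstance(outputs, dict):
        assert isinstance(outputs, tf.Tensor)
        outputs = {'loss': outputs}
    static_inputs = static_inputs or {}

    # Materialize as lists so names and tensors stay aligned and can be concatenated.
    output_names = list(outputs.keys())
    output_tensors = list(outputs.values())

    tic = ti.default_timer()
    step = 0
    for step in range(1, num_iter + 1):
        feed_dict = {}
        if batch_generator is not None:
            # Each batch must be a sequence aligned with `inputs`.
            feed_dict.update(zip(inputs, next(batch_generator)))
        feed_dict.update(static_inputs)

        if step % print_interval == 0:
            # Run the training ops and the monitored outputs in one call,
            # then keep only the output values.
            output_values = sess.run(operations + output_tensors,
                                     feed_dict=feed_dict)[len(operations):]

            toc = ti.default_timer()
            # Remaining time = average time per step so far * steps left.
            eta = (toc - tic) / step * (num_iter - step)
            log = '[Step: {}/{} ETA: {:.0f}s]'.format(step, num_iter, eta)

            for output_name, output in zip(output_names, output_values):
                log += ' {}: {:.4f}'.format(output_name, output)

            print(log)
        else:
            sess.run(operations, feed_dict=feed_dict)
    toc = ti.default_timer()
    # The final line reports total elapsed time, not an ETA.
    log = '[Step: {}/{} Elapsed: {:.0f}s]'.format(step, num_iter, toc - tic)
    print(log)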
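Because `fit_generator` needs a live TensorFlow session, the pattern it implements can be easier to see in a framework-free sketch. The version below is an illustration, not the author's code: `run_step(batch, fetch_outputs)` is a hypothetical stand-in for `sess.run`, and `toy_batches` / `toy_run_step` are made-up helpers. It shows batches drawn with `next()`, outputs fetched only every `print_interval` steps, and the ETA estimated as average time per step multiplied by the remaining steps (for example, 20 s after 100 of 300 steps gives 0.2 s/step × 200 = 40 s).

```python
import timeit


def fit_loop(run_step, num_iter, batch_generator=None, output_names=(),
             print_interval=100):
    """Framework-free sketch of the fit_generator loop (not the original code).

    run_step(batch, fetch_outputs) is a hypothetical stand-in for sess.run:
    it performs one training step and, when fetch_outputs is True, returns
    scalar values aligned with output_names.
    """
    logs = []
    tic = timeit.default_timer()
    step = 0
    for step in range(1, num_iter + 1):
        # Pull the next batch, mirroring next(batch_generator) in the original.
        batch = next(batch_generator) if batch_generator is not None else None
        if step % print_interval == 0:
            values = run_step(batch, True)
            elapsed = timeit.default_timer() - tic
            # ETA = average time per step so far * remaining steps.
            eta = elapsed / step * (num_iter - step)
            log = '[Step: {}/{} ETA: {:.0f}s]'.format(step, num_iter, eta)
            for name, value in zip(output_names, values):
                log += ' {}: {:.4f}'.format(name, value)
            logs.append(log)
            print(log)
        else:
            # Non-logging steps only run the training step.
            run_step(batch, False)
    elapsed = timeit.default_timer() - tic
    logs.append('[Step: {}/{} Elapsed: {:.0f}s]'.format(step, num_iter, elapsed))
    print(logs[-1])
    return logs


def toy_batches():
    """Hypothetical infinite generator yielding (inputs, labels) tuples."""
    i = 0
    while True:
        yield (i, 2 * i)
        i += 1


def toy_run_step(batch, fetch_outputs):
    """Hypothetical training step whose synthetic 'loss' is always 0.5."""
    x, y = batch
    loss = abs(y - 2 * x) + 0.5
    return [loss] if fetch_outputs else None


logs = fit_loop(toy_run_step, num_iter=300, batch_generator=toy_batches(),
                output_names=['loss'], print_interval=100)
```

Running it prints three progress lines (steps 100, 200 and 300) followed by a final elapsed-time line; the `loss` values are synthetic. Keeping non-logging steps to the training call alone mirrors how the original only adds the monitored tensors to `sess.run` on logging steps.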