helpers.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def __init__(self, inputs, outputs, updates=None, defaults=None,
             training=None):
    """Wrap a set of input tensors, output tensors and update ops.

    Parameters
    ----------
    inputs : Mapping, tuple/list, or a single input
        If a Mapping, its keys become ``self.inputs_name`` and its values
        are used as the inputs. A non-sequence input is wrapped in a list.
        The inputs are flattened with ``flatten_list(..., level=None)`` and
        stored in ``self.inputs``. When no names were given (and
        ``inputs_name`` is not already set on the instance), each name is
        the part of ``input.name`` before the first ``':'``.
        NOTE(review): with a Mapping whose values are themselves nested
        sequences, flattening may make ``inputs`` longer than
        ``inputs_name`` -- confirm callers never do this.
    outputs : tuple/list or a single output
        Flattened into ``self.outputs``. ``self.return_list`` is False only
        when a single (non-sequence) output was passed.
    updates : Mapping, iterable, tf.Operation, or None
        A ``{param: new_value}`` mapping, or an iterable whose items are
        either ``(param, new_value)`` pairs (turned into ``tf.assign``) or
        ready-made ops. They are combined with ``tf.group`` into
        ``self.updates_ops``. A ``tf.Operation`` is stored as-is.
        ``None`` (the default) means no updates.
    defaults : Mapping, iterable of key/value pairs, or None
        Copied into a new dict stored as ``self.defaults`` (presumably
        default feed values for inputs -- confirm with ``__call__``).
        ``None`` (the default) means an empty dict.
    training : object, optional
        Stored unchanged as ``self.training``.
    """
    self.training = training
    # ====== validate input ====== #
    if isinstance(inputs, Mapping):
        # On Python 3 keys()/values() return views: materialize them so the
        # names are indexable and the values are a real list.
        self.inputs_name = list(inputs.keys())
        inputs = list(inputs.values())
    elif not isinstance(inputs, (tuple, list)):
        inputs = [inputs]
    self.inputs = flatten_list(inputs, level=None)
    if not hasattr(self, 'inputs_name'):
        # Strip the ":<output index>" suffix from each tensor name.
        self.inputs_name = [i.name.split(':')[0] for i in self.inputs]
    # ====== defaults ====== #
    # Always copy, so later changes to the caller's dict do not leak in.
    self.defaults = {} if defaults is None else dict(defaults)
    # ====== validate outputs ====== #
    return_list = isinstance(outputs, (tuple, list))
    if not return_list:
        outputs = (outputs,)
    self.outputs = flatten_list(list(outputs), level=None)
    self.return_list = return_list
    # ====== validate updates ====== #
    if updates is None:
        updates = []
    elif isinstance(updates, Mapping):
        updates = list(updates.items())
    # Ops created inside this context (the assigns and the group) get a
    # control dependency on every output.
    with tf.control_dependencies(self.outputs):
        if isinstance(updates, tf.Operation):
            # Already a TensorFlow op: use it directly.
            self.updates_ops = updates
        else:
            updates_ops = []
            for update in updates:
                if isinstance(update, (tuple, list)):
                    p, new_p = update
                    updates_ops.append(tf.assign(p, new_p))
                else:  # assumed to already be an assign op
                    updates_ops.append(update)
            self.updates_ops = tf.group(*updates_ops)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号