helpers.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def set_shape(tensor, shape):
  """ Fill in missing static shape information of `tensor` from `shape`.

  Each dimension is merged independently:

  * known in the tensor and given in `shape`: the two must agree;
  * unknown in the tensor: the value from `shape` is used (may be None);
  * known in the tensor, None in `shape`: the known value is kept.

  If the tensor's rank itself is unknown there is nothing to merge against,
  so `shape` is applied as-is.

  Parameters
  ----------
  tensor : Tensor
    tensor whose static shape is updated in place via `tensor.set_shape`.
  shape : int, None, `tf.Dimension`, or a sequence of those
    shape information to merge in; normalized with `as_tuple`. `None`
    entries mean "unknown" and `tf.Dimension` entries are unwrapped to
    their `.value`.

  Returns
  -------
  The same `tensor` object (for call chaining).

  Raises
  ------
  ValueError
    if `tensor` is not a tensor, if its known rank differs from the length
    of `shape`, or if a known dimension conflicts with the given value.
  """
  if not is_tensor(tensor):
    raise ValueError('tensor must be instance of `Tensor`.')
  shape = as_tuple(shape)
  # Unwrap `tf.Dimension` entries so every element is a plain value or None.
  # `shape` itself is kept untouched for the error message below.
  shape_values = [s.value if isinstance(s, tf.Dimension) else s
                  for s in shape]
  old_shape = tensor.get_shape()
  ndims = old_shape.ndims
  # Unknown rank: iterating `old_shape` would fail and there is no per-axis
  # information to reconcile, so hand the given shape straight to TF.
  if ndims is None:
    tensor.set_shape(shape_values)
    return tensor
  # ====== Test ====== #
  if ndims != len(shape):
    raise ValueError("The tensor has %d dimensions, but the given shape "
                     "has %d dimension." % (ndims, len(shape)))
  # ====== DO it ====== #
  new_shape = []
  for old, new in zip(old_shape, shape_values):
    old_value = old.value
    if old_value is None:
      # TF does not know this axis: take the caller's value (possibly None).
      new_shape.append(new)
    elif new is None or old_value == new:
      # TF already knows this axis and the caller does not contradict it.
      new_shape.append(old_value)
    else:
      raise ValueError("Known shape information mismatch, from tensorflow"
          ":%s, and given shape:%s." %
          (str(old_shape.as_list()), str(shape)))
  tensor.set_shape(new_shape)
  return tensor


# ===========================================================================
# VALUE MANIPULATION
# ===========================================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号