net.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:bigan 作者: jeffdonahue 项目源码 文件源码
def __init__(self, value, shape=None, index_max=None):
        """
        Wrap a (possibly symbolic) value together with its static shape.

        value: A Theano Tensor, shared variable, or constant value.  May not
             itself be an Output.
        shape: May be None (default) if non-symbolic shape is accessible by
             value.get_value().shape (as in a Theano shared variable --
             tried first) or by value.shape (as in a NumPy array).
             Otherwise (e.g., if value is a symbolic Theano tensor), shape
             should be specified as an iterable of non-negative integers,
             one per dimension of value (any iterable, including a
             generator, is accepted).  If no static shape can be found,
             self.shape is left as None.
             NOTE(review): -1 "don't care" entries (e.g., for batch size)
             are NOT accepted by the validation below, although they were
             once documented as allowed -- confirm which is intended.
        index_max: If value is integer-typed (signed or unsigned), index_max
             may be used to specify its maximum value.  e.g., a batch of N
             one-hot vectors, each representing a word in a 500 word
             vocabulary, could be specified with an integer-typed Tensor
             with values in [0, 1, ..., 499], and index_max=500.

        Raises:
            TypeError: if value is an Output.
            ValueError: if index_max is negative.
            AssertionError: if shape has non-integral or negative entries or
                a length different from value.ndim, if index_max is given
                for a non-integer value, or if index_max is not integral.
        """
        # Local import: the file's top-level import block is not visible here.
        import numbers

        if isinstance(value, Output):
            raise TypeError("value may not be an Output")
        self.value = value
        if shape is None:
            # Infer a static shape: shared variables first, then anything
            # exposing a concrete .shape (e.g. a NumPy array).
            try:
                shape = value.get_value().shape
            except AttributeError:
                try:
                    shape = value.shape
                    # A symbolic Theano tensor's .shape is itself a symbolic
                    # Variable, so no static shape is available.
                    if isinstance(shape, theano.Variable):
                        shape = None
                except AttributeError:
                    pass
        if shape is not None:
            # Materialize exactly once so generators/iterators are not
            # exhausted by validation before being stored.
            shape = tuple(shape)
            for s in shape:
                # numbers.Integral also covers Python 2 longs and NumPy
                # integer scalars, which a plain `int` check rejects.
                assert isinstance(s, numbers.Integral), \
                    'shape entries must be integers; got %r' % (s,)
                assert s >= 0, \
                    'shape entries must be non-negative; got %r' % (s,)
            # Normalize to plain Python ints so downstream code sees one type.
            shape = tuple(int(s) for s in shape)
            assert len(shape) == value.ndim, \
                ('shape %r has %d dims but value has ndim=%s'
                 % (shape, len(shape), value.ndim))
        self.shape = shape
        if index_max is not None:
            # index_max is validated only here, so a negative value raises
            # ValueError whether or not a static shape is known.
            is_integer_typed = (
                isinstance(value, numbers.Integral) or
                str(value.dtype).startswith(('int', 'uint')))
            assert is_integer_typed, \
                ('if index_max is given, value must be integer-typed; '
                 'was: %s' % getattr(value, 'dtype', type(value).__name__))
            assert index_max == int(index_max)
            index_max = int(index_max)
            if index_max < 0:
                raise ValueError('index_max must be non-negative')
        self.index_max = index_max
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号