def __init__(self, value, shape=None, index_max=None):
    """
    Wrap a value together with its (optional) static shape.

    value: A Theano Tensor, shared variable, or constant value.
    shape: May be None (default) if non-symbolic shape is accessible by
        value.get_value().shape (as in a Theano shared variable --
        tried first) or by value.shape (as in a NumPy array).
        Otherwise (e.g., if value is a symbolic Theano tensor), shape
        should be specified as an iterable of ints, where some may be -1
        for don't cares (e.g., batch size).
    index_max: If value is integer-typed (signed or unsigned), index_max
        may be used to specify its maximum value. e.g., a batch of N
        one-hot vectors, each representing a word in a 500 word
        vocabulary, could be specified with an integer-typed Tensor with
        values in [0, 1, ..., 499], and index_max=500.

    Raises:
        TypeError: if value is an Output; if a shape entry is not an
            integer; if index_max is given but value is not integer-typed.
        ValueError: if a shape entry is < -1; if len(shape) differs from
            value.ndim; if index_max is not integral or is negative.
    """
    # Local import keeps this block self-contained; numbers.Integral
    # covers Python ints (and Py2 longs) as well as NumPy integer scalars.
    import numbers

    if isinstance(value, Output):
        raise TypeError("value may not be an Output")
    self.value = value

    if shape is None:
        # Try to infer a concrete shape: shared variables first, then
        # anything exposing a .shape attribute (e.g. NumPy arrays).
        try:
            shape = value.get_value().shape
        except AttributeError:
            try:
                shape = value.shape
                # A symbolic tensor's .shape is itself a Theano expression
                # with no static information -- treat the shape as unknown.
                if isinstance(shape, theano.Variable):
                    shape = None
            except AttributeError:
                pass

    if shape is not None:
        dims = []
        for s in shape:
            if not isinstance(s, numbers.Integral):
                raise TypeError('shape entries must be integers; got %r'
                                % (s,))
            # -1 is the documented "don't care" marker (e.g. batch size).
            if s < -1:
                raise ValueError("shape entries must be non-negative, or -1 "
                                 "for don't care; got %r" % (s,))
            dims.append(int(s))
        shape = tuple(dims)
        if len(shape) != value.ndim:
            raise ValueError('shape %r has %d dims but value has ndim=%s'
                             % (shape, len(shape), value.ndim))
    self.shape = shape

    if index_max is not None:
        # Validated here regardless of whether the shape is known, so the
        # same rules apply in both cases.
        dtype = getattr(value, 'dtype', None)
        is_integer_typed = (isinstance(value, numbers.Integral) or
                            str(dtype).startswith(('int', 'uint')))
        if not is_integer_typed:
            raise TypeError('if index_max is given, value must be '
                            'integer-typed; was: %s' % dtype)
        # Accept integral values of any numeric type (e.g. 500.0), reject
        # fractional ones (e.g. 500.5).
        if int(index_max) != index_max:
            raise ValueError('index_max must be integral; got %r'
                             % (index_max,))
        index_max = int(index_max)
        if index_max < 0:
            raise ValueError('index_max must be non-negative')
    self.index_max = index_max
评论列表
文章目录