def get_output(self, h, nout=None, stddev=None,
               reparameterize=reparam, exp_reparam=exp_reparam):
    """Apply a fully-connected (dense) layer to ``h``.

    Two input modes are selected by ``h.index_max``:

    * ``index_max is None``: ``h.value`` is flattened to 2-D when it has
      more than two dimensions. It is then right-multiplied by a weight
      matrix ``W`` of shape ``(nin, nout)``, where ``nin`` is the product
      of ``h.shape[1:]``.
    * otherwise: ``h.value`` must be 1-D. Its entries index rows of ``W``
      (shape ``(index_max, nout)``), i.e. an embedding-style lookup.
      NOTE(review): assumes the entries are integers in
      ``[0, index_max)``; this is not checked here.

    Args:
        h: input wrapper exposing ``.value``, ``.shape`` and ``.index_max``.
        nout: number of output units (an int or NumPy integer), or a tuple
            giving the per-example output shape. For a tuple, ``W`` has
            ``prod(nout)`` columns and the result is reshaped to
            ``(-1,) + nout``.
        stddev, reparameterize, exp_reparam: forwarded unchanged to
            ``self.weights``.

    Returns:
        ``Output`` wrapping the layer's symbolic result.

    Raises:
        AssertionError: if ``nout`` is neither an integer nor a tuple, if
            ``index_max`` is < 1, or if indexed input is not 1-D.
    """
    h, h_shape, h_max = h.value, h.shape, h.index_max
    # ``np.int`` was removed in NumPy 1.24; it was an alias of builtin ``int``.
    nin = np.prod(h_shape[1:], dtype=int) if h_max is None else h_max

    out_shape_specified = isinstance(nout, tuple)
    if out_shape_specified:
        out_shape = nout
    else:
        # NumPy integer scalars are not ``int`` subclasses; accept them too.
        assert isinstance(nout, (int, np.integer)), \
            'FC: nout must be an int or a tuple of ints; was: %r' % (nout,)
        out_shape = (nout,)
    # W always has a flat output axis; a tuple shape is restored below.
    nout = np.prod(out_shape)

    # Axis 0 of W is its input axis (``nin_axis`` for self.weights).
    nin_axis = [0]
    W = self.weights((nin, nout), stddev=stddev,
                     reparameterize=reparameterize, nin_axis=nin_axis,
                     exp_reparam=exp_reparam)

    if h_max is None:
        if h.ndim > 2:
            # Keep the leading axis (presumably batch), collapse the rest,
            # so h becomes 2-D with nin columns.
            h = T.flatten(h, 2)
        out = T.dot(h, W)
    else:
        assert nin >= 1, 'FC: h.index_max must be >= 1; was: %s' % (nin,)
        assert h.ndim == 1, \
            'FC: indexed input must be 1-D; got ndim=%s' % (h.ndim,)
        # Row lookup: each entry of h selects one row of W.
        out = W[h]

    if out_shape_specified:
        # Restore the requested per-example shape that was flattened into
        # ``nout`` above; an int ``nout`` keeps the 2-D result unchanged.
        out = out.reshape((-1,) + tuple(out_shape))
    return Output(out)
评论列表
文章目录