net.py source file

python

Project: bigan  Author: jeffdonahue
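def get_output(self, h, nout=None, stddev=None,
               reparameterize=reparam, exp_reparam=exp_reparam):
    # Unpack the wrapped input: symbolic value, static shape, and index_max,
    # which is set only for integer-index inputs.
    h, h_shape, h_max = h.value, h.shape, h.index_max
    # Input width: product of the non-batch dims (dense) or index_max (indices).
    nin = np.prod(h_shape[1:], dtype=int) if (h_max is None) else h_max
    out_shape_specified = isinstance(nout, tuple)
    if out_shape_specified:
        out_shape = nout
    else:
        assert isinstance(nout, int)
        out_shape = (nout,)
    # Total number of output units across all requested output dims.
    nout = np.prod(out_shape)
    nin_axis = [0]
    W = self.weights((nin, nout), stddev=stddev,
                     reparameterize=reparameterize, nin_axis=nin_axis,
                     exp_reparam=exp_reparam)
    if h_max is None:
        # Dense input: flatten to 2D (batch, features), then multiply by W.
        if h.ndim > 2:
            h = T.flatten(h, 2)
        out = T.dot(h, W)
    else:
        # Index input: select one row of W per index.
        assert nin >= 1, 'FC: h.index_max must be >= 1; was: %s' % (nin,)
        assert h.ndim == 1
        out = W[h]
    return Output(out)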
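
The two branches of get_output can be illustrated outside Theano. The following is a minimal NumPy sketch, not the project's code: `fc_forward`, `W_dense`, `W_embed`, and the example shapes are hypothetical names and values chosen for illustration, and it assumes indices lie in the range [0, index_max). It shows that the dense branch flattens every non-batch axis before the matrix product, and that the index branch (`W[h]`) gives the same result as multiplying one-hot rows by `W`.

```python
import numpy as np


def fc_forward(h, W, index_max=None):
    """NumPy stand-in for the two branches of get_output (hypothetical helper)."""
    if index_max is None:
        # Dense input: collapse all non-batch axes into one, then multiply by W.
        h2d = h.reshape(h.shape[0], -1)
        return h2d @ W
    # Index input: a 1-D vector of integer indices selects rows of W.
    assert h.ndim == 1
    return W[h]


rng = np.random.default_rng(0)

# Dense case: a 4-D batch is flattened to (batch, 3 * 2 * 2) before the product.
batch = rng.standard_normal((4, 3, 2, 2))
nin = int(np.prod(batch.shape[1:]))
W_dense = rng.standard_normal((nin, 5))
dense_out = fc_forward(batch, W_dense)

# Index case: row lookup matches an explicit one-hot matrix product.
labels = np.array([2, 0, 1, 2])
W_embed = rng.standard_normal((3, 5))
lookup_out = fc_forward(labels, W_embed, index_max=3)
one_hot = np.eye(3)[labels]
assert np.allclose(lookup_out, one_hot @ W_embed)
```

Indexing rows directly avoids materializing the one-hot matrix, which is presumably why the method takes this path when index_max is set.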