basic.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def flatten(x, outdim=1):
    """
    Collapse the trailing dimensions of x into a single dimension.

    The first ``outdim - 1`` dimensions of x keep their sizes; all
    remaining dimensions are merged into one final dimension whose size
    is inferred by the reshape (requested as ``-1``).

    Parameters
    ----------
        x : theano.tensor.var.TensorVariable
            the variable to reshape.

        outdim : int
            the number of dimensions of the returned variable. A value
            of 1 is always accepted, even for a scalar x; larger values
            must not exceed ``x.ndim``.

    Returns
    -------
    theano.tensor.var.TensorVariable
        the flattened variable, with ``outdim`` dimensions. Kept
        dimensions retain their broadcastable flag; the merged dimension
        is flagged broadcastable only when every dimension merged into
        it was broadcastable.

    Raises
    ------
    ValueError
        if ``outdim`` is smaller than 1, or is larger than 1 and also
        larger than ``x.ndim``.
    """
    # outdim == 1 is valid for any input; only values above 1 are
    # checked against the rank of x.
    below_range = outdim < 1
    above_range = outdim > 1 and outdim > x.ndim
    if below_range or above_range:
        raise ValueError('outdim %s out of bound [1, %d)'
                         % (outdim, x.ndim + 1))

    n_kept = outdim - 1

    # Keep the leading sizes symbolic and let reshape infer the last one.
    target_shape = (tuple(x.shape[:n_kept]) + (-1,)
                    if outdim > 1 else (-1,))
    flat = x.reshape(target_shape)

    # Rebuild the broadcast pattern: untouched for the kept dimensions,
    # and the conjunction of the merged dimensions' flags for the last.
    kept_flags = x.broadcastable[:n_kept]
    merged_flag = python_all(x.broadcastable[n_kept:])
    pattern = kept_flags + (merged_flag,)

    # Pass only the axes flagged broadcastable to addbroadcast.
    bcast_axes = [axis for axis in range(outdim) if pattern[axis]]
    return theano.tensor.addbroadcast(flat, *bcast_axes)


# class TileGrad(Op):
#     """
#     Calculates the gradient of the Tile Op.
#     """
#     # this is so weird, I can't think of how to make this a general thing.
#     def make_node(self, x, reps, g_out):
#         return gof.Apply(self, [x, reps, g_out], [x.type()])
#
#     def perform(self, node, inp, out):
#         x, reps, g_out = inp
#         gx, = out
#         xsh = x.shape
#         if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
#             gx[0] = numpy.sum(g_out, axis=0)
#         else:
#             raise NotImplementedError('x.shape, reps combination not '
#                                       'supported', (x.shape, reps))
#
# tilegrad = TileGrad()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号