nets.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:2048 作者: vhalis 项目源码 文件源码
def __init__(self,
                 hidden_sizes=DEFAULT_HIDDEN_SIZES,
                 weights=DEFAULT_WEIGHTS,
                 inputs=DEFAULT_INPUTS,
                 outputs=DEFAULT_OUTPUTS,
                 weight_spread=None,
                 weight_middle=None):
        """
        Validate the layer layout and the weights, then store them on the Net.

        @hidden_sizes: An iterable of integers that describe the sizes of the
                       hidden layers of the Net.
        @weights: May be a function that returns arrays to use as weights.
                  If so, must take an iterable of sizes to create weights for
                  and must return the same data as described below.
                  Else it must be numpy.ndarrays of dtype=float and proper sizes
                  in the proper order provided in a sliceable.
                  A falsy value means weights are generated with
                  Net.random_weights.
        @inputs: The integer number of inputs.
        @outputs: The integer number of outputs.
        @weight_spread: Forwarded to Net.random_weights when no weights are
                        supplied.
        @weight_middle: Forwarded to Net.random_weights when no weights are
                        supplied.

        Raises ValueError if inputs/outputs are not integers, if hidden_sizes
        is not an iterable of integers, or if weights is not a sliceable of
        float numpy.ndarrays.
        Raises AssertionError if a weight array's shape does not match the
        sizes of the two layers it connects.
        """
        if not isinstance(inputs, int) or not isinstance(outputs, int):
            raise ValueError('Number of inputs and outputs must be integers')
        if not hasattr(hidden_sizes, '__iter__'):
            raise ValueError('Sizes of hidden layers must be integers'
                             ' provided in an iterable')
        # Materialize once: hidden_sizes is walked both for validation and for
        # building self.sizes, and a one-shot iterable would be exhausted by
        # the first pass.
        hidden_sizes = tuple(hidden_sizes)
        if not all(isinstance(i, int) for i in hidden_sizes):
            raise ValueError('Sizes of hidden layers must be integers'
                             ' provided in an iterable')

        # Layer sizes in order: input layer, hidden layers, output layer.
        self.sizes = tuple(chain((inputs,),
                                 hidden_sizes,
                                 (outputs,)))
        if weights and callable(weights):
            weights = weights(self.sizes)
        if weights:
            # '__getslice__' only exists on Python 2 sequences; Python 3
            # sequences implement slicing through '__getitem__', so accept
            # either one.
            sliceable = (hasattr(weights, '__getslice__')
                         or hasattr(weights, '__getitem__'))
            if (not sliceable
                    or not all(isinstance(arr, numpy.ndarray)
                               for arr in weights)
                    or not all(arr.dtype == float for arr in weights)):
                raise ValueError('Weights of hidden layers must be'
                                 ' numpy.ndarrays'
                                 ' with dtype=float provided in a sliceable')

        self.inputs = inputs
        self.outputs = outputs
        self.weights = weights or Net.random_weights(self.sizes,
                                                      weight_spread,
                                                      weight_middle)
        # The idx-th weight array connects layer idx to layer idx + 1.
        for idx, w in enumerate(self.weights):
            assert(w.shape == (self.sizes[idx], self.sizes[idx+1]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号