gendata.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def __init__(self, pvals, shape, seed=0):
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        self.__rng = np.random.RandomState(seed)
        self.nclasses = len(pvals)
        self.shape = shape
        self.size = 1
        for s in shape:
            self.size = self.size * s
        self.As = self.__rng.uniform(-1, 1, (self.size, self.size, self.nclasses,))
        self.bs = self.__rng.uniform(-1, 1, (self.size, self.nclasses,))

        self.accum = []
        s = 0
        for pval in pvals:
            s = s + pval
            self.accum.append(s)
        self.accum[-1] = 2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号