snake.py source code

python

Project: snake  Author: rhinech
import numpy as np
from numpy import array  # assumed: the original module imports numpy's array


def gen_samples(self, num_samples):
    """Generate samples for ML near the snake."""

    points = []  # the sample points
    labels = []  # the soft two-class labels
    whichs = []  # the snake node each sample belongs to
    deri_g = []  # partial derivative with respect to g
    deri_T = []  # partial derivative with respect to T
    counter = 0
    assert num_samples % self.length == 0
    for i, (v, n) in enumerate(zip(self.vertices, self.normals())):
        # integer division: np.linspace needs an int sample count in Python 3
        for d in np.linspace(-1, 1, num_samples // self.length):
            # geometry: offset vertex v along its normal n by r
            r = 2 * self.widths[i] * d
            s = v + r * n
            l = array([0.5 * (1. - np.tanh(d)),
                       0.5 * (1. + np.tanh(d))])
            points.append(s)
            labels.append(l)
            whichs.append(i)
            # calculate derivatives
            cosh_d = np.cosh(d)
            deri_g.append(1 / (4 * self.widths[i] * cosh_d * cosh_d))
            deri_T.append(d / (2 * self.widths[i] * cosh_d * cosh_d))
            counter += 1
            if counter == num_samples:
                return array(points), array(labels), array(whichs), array(deri_g), array(deri_T)
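To try the method outside its project, here is a minimal sketch under stated assumptions. `ToySnake`, its `normals()` implementation (unit normals of a closed polygon from central differences) and `label_from_offset` are hypothetical stand-ins, not part of the rhinech/snake project. The sketch also reads g as the signed offset r along the normal, so that `deri_g` should equal the derivative of the second label component with respect to r. That reading is an assumption, and `deri_T` is reproduced as written without being checked.

```python
import numpy as np


class ToySnake:
    """Hypothetical minimal stand-in for the project's snake class."""

    def __init__(self, vertices, widths):
        self.vertices = np.asarray(vertices, dtype=float)
        self.widths = np.asarray(widths, dtype=float)
        self.length = len(self.vertices)

    def normals(self):
        # Assumption: unit normals of a closed polygon, built by rotating
        # the central-difference tangent (next vertex minus previous vertex)
        tangents = np.roll(self.vertices, -1, axis=0) - np.roll(self.vertices, 1, axis=0)
        normals = np.stack([tangents[:, 1], -tangents[:, 0]], axis=1)
        return normals / np.linalg.norm(normals, axis=1, keepdims=True)

    def gen_samples(self, num_samples):
        # Same sampling scheme as gen_samples above, without the counter
        assert num_samples % self.length == 0
        per_vertex = num_samples // self.length
        points, labels, whichs, deri_g, deri_T = [], [], [], [], []
        for i, (v, n) in enumerate(zip(self.vertices, self.normals())):
            w = self.widths[i]
            for d in np.linspace(-1, 1, per_vertex):
                points.append(v + 2 * w * d * n)  # offset r = 2 * w * d
                labels.append([0.5 * (1 - np.tanh(d)), 0.5 * (1 + np.tanh(d))])
                whichs.append(i)
                sech2 = 1 / np.cosh(d) ** 2
                deri_g.append(sech2 / (4 * w))
                deri_T.append(d * sech2 / (2 * w))
        return (np.array(points), np.array(labels), np.array(whichs),
                np.array(deri_g), np.array(deri_T))


def label_from_offset(r, w):
    # Hypothetical helper: second label component as a function of offset r
    return 0.5 * (1 + np.tanh(r / (2 * w)))


snake = ToySnake([[0, 0], [1, 0], [1, 1], [0, 1]], widths=[0.5] * 4)
points, labels, whichs, deri_g, deri_T = snake.gen_samples(20)
print(points.shape)  # (20, 2)
```

Because r = 2wd, the second label component is 0.5(1 + tanh(r/(2w))), whose derivative with respect to r is 1/(4w cosh²(d)). This matches the `deri_g` expression, and a finite-difference comparison against `label_from_offset` confirms it numerically.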