def gen_samples(self, num_samples):
    """Generate sample points for ML near the snake.

    For each of the first ``self.length`` vertices ``v`` (with normal ``n``
    from ``self.normals()`` and width ``w = self.widths[i]``), place
    ``num_samples // self.length`` samples at ``s = v + r * n``. The offset
    is ``r = 2 * w * d``, with ``d`` evenly spaced over [-1, 1].

    Args:
        num_samples: total number of samples to generate; must be a
            multiple of ``self.length``.

    Returns:
        A tuple of NumPy arrays ``(points, labels, whichs, deri_g, deri_T)``,
        one entry per sample, in vertex order:

        - points: sample positions ``v + r * n``.
        - labels: pairs ``[0.5 * (1 - tanh(d)), 0.5 * (1 + tanh(d))]``.
        - whichs: index of the vertex each sample was generated from.
        - deri_g: ``1 / (4 * w * cosh(d)**2)`` (the "partial derivative to
          g" in the original notation).
        - deri_T: ``d / (2 * w * cosh(d)**2)`` (the "partial derivative to
          T" in the original notation).

    Raises:
        AssertionError: if ``num_samples`` is not a multiple of
            ``self.length``.
    """
    assert num_samples % self.length == 0
    # np.linspace requires an integer count; true division (/) would hand it
    # a float and raise TypeError on current NumPy.
    per_vertex = num_samples // self.length
    offsets = np.linspace(-1, 1, per_vertex)  # loop-invariant, compute once

    points = []  # the sample points
    labels = []  # the labels
    whichs = []  # the corresponding node for each sample
    deri_g = []  # the partial derivative to g
    deri_T = []  # the partial derivative to T

    for i, (v, n) in enumerate(zip(self.vertices, self.normals())):
        # Only the first self.length vertices contribute, which yields
        # exactly num_samples samples when enough vertices are available.
        if i >= self.length:
            break
        width = self.widths[i]
        for d in offsets:
            # geometry: step along the normal, scaled by the local width
            r = 2 * width * d
            points.append(v + r * n)
            tanh_d = np.tanh(d)
            labels.append(np.array([0.5 * (1. - tanh_d),
                                    0.5 * (1. + tanh_d)]))
            whichs.append(i)
            # derivatives
            cosh_d = np.cosh(d)
            deri_g.append(1 / (4 * width * cosh_d * cosh_d))
            deri_T.append(d / (2 * width * cosh_d * cosh_d))

    # Return after the loop rather than from inside it. Then num_samples == 0
    # or a short vertex list still yields arrays instead of an implicit None.
    return (np.array(points), np.array(labels), np.array(whichs),
            np.array(deri_g), np.array(deri_T))
# Scraped web-page navigation residue ("Comment list", "Table of contents"); not code.