def gen_data2(k=0, min_length=50, max_length=55, n_batch=5, freq=2.):
    """Generate one (X, y) batch of a one-step-ahead sine prediction task.

    The time axis is the unit-spaced grid ``t = k*n_batch, ..., (k+1)*n_batch``
    (``n_batch + 1`` points). The inputs of consecutive ``k`` values therefore
    tile the time axis contiguously. Each target ``y[i]`` is the sample one
    step after the input ``X[i]``.

    Note that the sine argument is ``t * freq / (2*pi)``. So ``freq / (2*pi)``
    is the phase advance per step in radians, and one full period spans
    ``4*pi**2 / freq`` steps (about 19.7 steps for the default ``freq=2``).

    Args:
        k: Index of the window along the time axis.
        min_length: Unused; kept for signature compatibility.
        max_length: Unused; kept for signature compatibility.
        n_batch: Number of (input, target) pairs to produce.
        freq: Scale of the phase advance per step (see note above).

    Returns:
        Tuple ``(X, y)`` of float arrays with shapes ``(n_batch, input_size)``
        and ``(n_batch, output_size)``. ``input_size`` and ``output_size`` are
        read from module scope. Exactly ``n_batch`` values are reshaped, so
        both must equal 1; otherwise ``reshape`` raises ``ValueError``.
    """
    # Unit-spaced grid of n_batch + 1 points starting at k * n_batch. The
    # extra trailing point supplies the target for the last input.
    t = k * n_batch + np.arange(n_batch + 1, dtype=float)
    phase = t * freq / (2 * np.pi)
    # Separate np.sin calls give X and y independent buffers (no aliasing).
    X = np.sin(phase[:-1]).reshape((n_batch, input_size))
    y = np.sin(phase[1:]).reshape((n_batch, output_size))
    return (X, y)
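# NOTE: the original source page ended with "Comment list" and
# "Table of contents" headings here (web-page residue, not code).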