recurrent_network_2a.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:seqrnns 作者: x75 项目源码 文件源码
def gen_data2(k = 0, min_length=50, max_length=55, n_batch=5, freq = 2.):
    print "k", k
    # t = np.linspace(0, 2*np.pi, n_batch)
    t = np.linspace(k*n_batch, (k+1)*n_batch+1, n_batch+1, endpoint=False)
    # print "t.shape", t.shape, t, t[:-1], t[1:]
    # freq = 1.
    Xtmp = np.sin(t[:-1] * freq / (2*np.pi))
    print Xtmp.shape
    # Xtmp = [np.sin(t[i:i+max_length]) for i in range(n_batch)]
    # print len(Xtmp)
    X = np.array(Xtmp).reshape((n_batch, input_size))
    # X = 
    # y = np.zeros((n_batch,))
    y = np.sin(t[1:] * freq / (2 * np.pi)).reshape((n_batch, output_size))
    # print X,y
    # print X.shape, y.shape
    # for i in range(batch_size):
    #     pl.subplot(211)
    #     pl.plot(X[i,:,0])
    #     # pl.subplot(312)
    #     # pl.plot(X[i,:,1])
    # pl.subplot(212)
    # pl.plot(y)
    # pl.show()

    return (X,y)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号