paren_task.py source code

python

Project: GORU-tensorflow  Author: jingli9111
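```python
import numpy as np


def paren_data(T, n_data):
    """Generate a batch for the parenthesis-counting task.

    Returns x of shape (n_data, T) holding integer tokens and y of shape
    (n_data, T, n_paren) holding the running count of each open
    parenthesis type after every step.
    """
    MAX_COUNT = 10  # counts are clamped to [0, MAX_COUNT]
    n_paren = 10    # number of parenthesis types
    n_noise = 10    # number of noise tokens

    # Token 2k opens type k, token 2k+1 closes type k, and tokens
    # 2*n_paren .. 2*n_paren + n_noise - 1 are noise; all drawn at random.
    inputs = (np.random.rand(T, n_data) * (n_paren * 2 + n_noise)).astype(np.int32)
    counts = np.zeros((n_data, n_paren), dtype=np.int32)
    targets = np.zeros((T, n_data, n_paren), dtype=np.int32)
    opening_parens = (np.arange(0, n_paren) * 2)[None, :]  # shape (1, n_paren)
    closing_parens = opening_parens + 1
    for i in range(T):
        # Boolean (n_data, n_paren) masks: which type was opened/closed at step i.
        opened = np.equal(inputs[i, :, None], opening_parens)
        counts = np.minimum(MAX_COUNT, counts + opened)
        closed = np.equal(inputs[i, :, None], closing_parens)
        counts = np.maximum(0, counts - closed)  # closing an unopened type is a no-op
        targets[i, :, :] = counts

    # Switch from time-major to batch-major layout.
    x = np.transpose(inputs, [1, 0])      # (n_data, T)
    y = np.transpose(targets, [1, 0, 2])  # (n_data, T, n_paren)
    return x, y
```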
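The indexing in `paren_data` implies a token encoding: for parenthesis type k, token 2k opens it and token 2k+1 closes it, while tokens 20 through 29 (with the default n_paren = n_noise = 10) are noise. Below is a minimal plain-Python sketch that recomputes the targets for a single sequence, one token at a time. The name `reference_counts` is hypothetical and is not part of GORU-tensorflow; it only restates the semantics of the vectorized loop above.

```python
MAX_COUNT = 10  # same clamp as paren_data
N_PAREN = 10    # parenthesis types; tokens 0..19 are parens, 20..29 are noise


def reference_counts(seq, n_paren=N_PAREN, max_count=MAX_COUNT):
    """Hypothetical helper: per-step open counts for one token sequence."""
    counts = [0] * n_paren
    history = []
    for tok in seq:
        if tok < 2 * n_paren:
            kind, is_close = divmod(int(tok), 2)  # type index, open/close bit
            if is_close:
                counts[kind] = max(0, counts[kind] - 1)  # never below zero
            else:
                counts[kind] = min(max_count, counts[kind] + 1)  # clamp at max
        # tokens >= 2 * n_paren are noise and leave the counts unchanged
        history.append(list(counts))
    return history


hist = reference_counts([0, 0, 1, 21, 3, 2, 1, 1])
print([h[0] for h in hist])  # [1, 2, 1, 1, 1, 1, 0, 0]
print([h[1] for h in hist])  # [0, 0, 0, 0, 0, 1, 1, 1]
```

For a batch `x, y = paren_data(T, n_data)`, `np.array(reference_counts(x[b]))` should equal `y[b]` for every row b, which makes a slow but readable sanity check on the generator.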