train.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:deep-learning 作者: ljanyst 项目源码 文件源码
def gen_batches(data, n_seqs, n_steps):
    """Create a generator that returns batches of size n_seqs x n_steps."""

    characters_per_batch = n_seqs * n_steps
    n_batches = len(data) // characters_per_batch

    # Keep only enough characters to make full batches
    data = data[:n_batches*characters_per_batch]
    data = data.reshape([n_seqs, -1])

    for n in range(0, data.shape[1], n_steps):
        x = data[:, n:n+n_steps]
        y = np.zeros_like(x)
        y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]
        yield x, y

#-------------------------------------------------------------------------------
# Parse commandline
#-------------------------------------------------------------------------------
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号