main.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:punctuator2 作者: ottokart 项目源码 文件源码
def get_minibatch(file_name, batch_size, shuffle, with_pauses=False):
    """Yield fixed-size mini-batches of subsequences loaded from ``file_name``.

    Each dataset item is indexed positionally: index 0 is collected into X,
    index 1 into Y and, when ``with_pauses`` is true, index 2 into P. Every
    yielded array is transposed so that the first axis is time (presumably
    shape ``(sequence_length, batch_size)``, assuming all items in a batch
    have equal length -- TODO confirm against ``data.load``).

    Args:
        file_name: Path handed to ``data.load``.
        batch_size: Number of subsequences per yielded batch.
        shuffle: If true, shuffle the loaded dataset in place before batching.
        with_pauses: If true, also collect and yield the pause features.

    Yields:
        ``(X, Y)`` as int32 arrays, or ``(X, Y, P)`` where P has dtype
        ``theano.config.floatX`` when ``with_pauses`` is true.

    Items left over after the last complete batch are dropped; if the dataset
    has fewer than ``batch_size`` items a warning is printed and nothing is
    yielded.
    """
    dataset = data.load(file_name)

    if shuffle:
        np.random.shuffle(dataset)

    if len(dataset) < batch_size:
        # Base the suggested minimum on the batch size this call actually
        # uses, not on the global MINIBATCH_SIZE. The single-argument print(...)
        # form prints identically on Python 2 and also parses on Python 3.
        print("WARNING: Not enough samples in '%s'. Reduce mini-batch size to %d or use a dataset with at least %d words." % (
            file_name,
            len(dataset),
            batch_size * data.MAX_SEQUENCE_LEN))

    X_batch = []
    Y_batch = []
    P_batch = []  # stays empty unless with_pauses is true

    for subsequence in dataset:
        X_batch.append(subsequence[0])
        Y_batch.append(subsequence[1])
        if with_pauses:
            P_batch.append(subsequence[2])

        if len(X_batch) == batch_size:
            # Transpose, because the model assumes the first axis is time
            X = np.array(X_batch, dtype=np.int32).T
            Y = np.array(Y_batch, dtype=np.int32).T
            if with_pauses:
                P = np.array(P_batch, dtype=theano.config.floatX).T
                yield X, Y, P
            else:
                yield X, Y

            X_batch = []
            Y_batch = []
            P_batch = []
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号