def get_minibatch(file_name, batch_size, shuffle, with_pauses=False):
    """Yield fixed-size mini-batches built from a stored dataset.

    The dataset is read with ``data.load(file_name)`` and iterated as a
    sequence of subsequences, where index 0 holds the inputs, index 1 the
    targets and (only read when ``with_pauses`` is true) index 2 the pauses.

    Args:
        file_name: Passed unchanged to ``data.load``.
        batch_size: Number of subsequences collected into each batch.
        shuffle: If true, the loaded dataset is shuffled in place with
            ``np.random.shuffle`` before batching.
        with_pauses: If true, ``subsequence[2]`` is collected as well and
            yielded as a third array.

    Yields:
        ``(X, Y)``, or ``(X, Y, P)`` when ``with_pauses`` is true. ``X`` and
        ``Y`` are ``int32`` arrays and ``P`` uses ``theano.config.floatX``.
        Every array is transposed before being yielded (assuming
        equal-length subsequences, the batch axis ends up second).

    Note:
        Only complete batches are yielded: subsequences left over after the
        last full batch are dropped. If the dataset holds fewer than
        ``batch_size`` subsequences, a warning is printed and nothing is
        yielded. Because this is a generator, the dataset is loaded and the
        warning printed on the first iteration step, not at call time.
    """
    dataset = data.load(file_name)

    if shuffle:
        np.random.shuffle(dataset)

    if len(dataset) < batch_size:
        # Use the batch_size actually requested by the caller; the previous
        # code read the module-level MINIBATCH_SIZE here, which reported the
        # wrong requirement whenever the two differed.
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(
            "WARNING: Not enough samples in '%s'. Reduce mini-batch size to %d "
            "or use a dataset with at least %d words." % (
                file_name,
                len(dataset),
                batch_size * data.MAX_SEQUENCE_LEN))

    X_batch = []
    Y_batch = []
    P_batch = []  # stays empty unless with_pauses is set

    for subsequence in dataset:
        X_batch.append(subsequence[0])
        Y_batch.append(subsequence[1])
        if with_pauses:
            P_batch.append(subsequence[2])

        if len(X_batch) == batch_size:
            # Transpose, because the model assumes the first axis is time
            X = np.array(X_batch, dtype=np.int32).T
            Y = np.array(Y_batch, dtype=np.int32).T

            if with_pauses:
                P = np.array(P_batch, dtype=theano.config.floatX).T
                yield X, Y, P
            else:
                yield X, Y

            # Start collecting the next batch from scratch.
            X_batch = []
            Y_batch = []
            P_batch = []
# (web page residue) Comment list
# (web page residue) Article table of contents