def generate_batch(s_sents, s_word2index, t_sents, t_word2index,
                   batch_size, maxlen):
    """Yield shuffled (source, one-hot target) batches forever.

    Every sentence is converted to word indices once, up front. Each epoch
    then draws a fresh random permutation of the sentence positions and
    yields ``len(s_sents) // batch_size`` full batches; the trailing
    sentences that do not fill a whole batch are skipped for that epoch.

    Args:
        s_sents: Source sentences, each an iterable of words.
        s_word2index: Mapping from source word to index. Must contain the
            key ``"UNK"``, whose index is passed to ``get_or_else`` as the
            fallback value.
        t_sents: Target sentences, indexed by the same positions as
            ``s_sents`` (so it needs at least ``len(s_sents)`` entries).
        t_word2index: Mapping from target word to index. Every target word
            is looked up directly, so a missing word raises ``KeyError``.
        batch_size: Number of sentences per batch; must be positive and no
            larger than ``len(s_sents)``.
        maxlen: Forwarded to ``sequence.pad_sequences`` as ``maxlen``.

    Yields:
        Tuples ``(sp_batch, tpc_batch)``: the padded source index batch and
        the one-hot encoded padded target batch, reshaped to
        ``(batch_size, -1, len(t_word2index))``.

    Raises:
        ValueError: On first iteration, if ``batch_size`` is not positive
            or there are fewer sentences than ``batch_size``. Without this
            check the generator would loop forever without yielding.
    """
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be positive, got {!r}".format(batch_size))
    num_sents = len(s_sents)
    num_batches = num_sents // batch_size
    if num_batches == 0:
        raise ValueError(
            "need at least batch_size={} sentences to form a batch, "
            "got {}".format(batch_size, num_sents))

    num_classes = len(t_word2index)

    # Convert to word indices once; the result does not depend on the epoch.
    si_sents = [[get_or_else(s_word2index, word, s_word2index["UNK"])
                 for word in sent]
                for sent in s_sents]
    ti_sents = [[t_word2index[word] for word in sent]
                for sent in t_sents]

    while True:
        # Shuffle positions rather than the sentence lists themselves.
        indices = np.random.permutation(np.arange(num_sents))
        # Inner loop runs for one epoch of full batches.
        for start in range(0, num_batches * batch_size, batch_size):
            batch_ix = indices[start:start + batch_size]
            s_batch = [si_sents[ix] for ix in batch_ix]
            t_batch = [ti_sents[ix] for ix in batch_ix]
            sp_batch = sequence.pad_sequences(s_batch, maxlen=maxlen)
            tp_batch = sequence.pad_sequences(t_batch, maxlen=maxlen)
            # One-hot encode the flattened targets, then restore the
            # (batch, timestep, class) layout.
            tpc_batch = np_utils.to_categorical(
                tp_batch.reshape(-1, 1),
                num_classes=num_classes,
            ).reshape(batch_size, -1, num_classes)
            yield sp_batch, tpc_batch
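# (web page residue, not part of the program) Comment list
# (web page residue, not part of the program) Table of contents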