def get_batch(self, data, bucket_id):
    """Turn every sample in one bucket into time-major model inputs.

    The whole bucket ``data[bucket_id]`` is used as the batch. Each sample
    is unpacked as ``(text_encoder_input, decoder_input,
    speech_encoder_input)``.

    Args:
      data: indexable by ``bucket_id``; each entry is a sequence of samples
        shaped as described above.
      bucket_id: index into both ``data`` and ``self.buckets``, where
        ``self.buckets[bucket_id]`` is ``(encoder_size, decoder_size)``.

    Returns:
      A 5-tuple ``(text_enc, speech_enc, dec, weights, seq_len)``:
        text_enc: ``encoder_size`` int32 arrays of shape ``(batch,)``;
          entry ``t`` holds position ``t`` of every reversed, PAD-filled
          text input.
        speech_enc: ``encoder_size * spscale`` lists; entry ``t`` holds row
          ``t`` of every flipped-and-transposed speech input.
        dec: ``decoder_size`` int32 arrays of shape ``(batch,)`` taken from
          ``[GO_ID] + decoder_input + [PAD_ID, ...]``.
        weights: ``decoder_size`` float32 arrays of shape ``(batch,)``;
          0.0 where the target (the next decoder input) is PAD, and always
          0.0 at the final step, which has no target.
        seq_len: int64 array of shape ``(batch,)`` with each text input's
          unpadded length.

    Note:
      ``spscale`` is read from the enclosing module, not passed in;
      presumably the number of speech rows per text position — confirm.
    """
    samples = data[bucket_id]
    batch_size = len(samples)
    encoder_size, decoder_size = self.buckets[bucket_id]

    text_encoder_inputs, speech_encoder_inputs, decoder_inputs = [], [], []
    seq_len = []
    for text_encoder_input, decoder_input, speech_encoder_input in samples:
        seq_len.append(len(text_encoder_input))
        # Text is reversed FIRST and PAD is appended afterwards, so padding
        # sits at the end of every row and seq_len marks where real tokens
        # stop. An input longer than encoder_size gets no padding and only
        # its first encoder_size (reversed) tokens are used below.
        encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(text_encoder_input))
        text_encoder_inputs.append(list(reversed(text_encoder_input)) + encoder_pad)
        # fliplr reverses axis 1 and .T moves that axis to the front, so the
        # rows of the result walk the reversed axis. NOTE(review): assumes a
        # 2-D (features, frames) input with at least encoder_size * spscale
        # frames — speech is not padded here; confirm against the loader.
        speech_encoder_inputs.append(np.fliplr(speech_encoder_input).T)
        # Decoder inputs get a leading GO symbol, then PAD up to decoder_size.
        decoder_pad_size = decoder_size - len(decoder_input) - 1
        decoder_inputs.append([data_utils.GO_ID] + decoder_input +
                              [data_utils.PAD_ID] * decoder_pad_size)

    # Re-index batch-major rows into time-major steps: one entry per step.
    # `range` (not the Python-2-only `xrange`) keeps this working on Python 3.
    batch_text_encoder_inputs = [
        np.array([row[t] for row in text_encoder_inputs], dtype=np.int32)
        for t in range(encoder_size)
    ]
    batch_speech_encoder_inputs = [
        [frames[t, :] for frames in speech_encoder_inputs]
        for t in range(encoder_size * spscale)
    ]

    batch_decoder_inputs, batch_weights = [], []
    for t in range(decoder_size):
        batch_decoder_inputs.append(
            np.array([row[t] for row in decoder_inputs], dtype=np.int32))
        if t + 1 < decoder_size:
            # The target of step t is decoder input t + 1; a PAD target gets
            # weight 0 so it does not contribute to the loss.
            weight = np.array(
                [row[t + 1] != data_utils.PAD_ID for row in decoder_inputs],
                dtype=np.float32)
        else:
            # The final step has no next input, hence no target.
            weight = np.zeros(batch_size, dtype=np.float32)
        batch_weights.append(weight)

    seq_len = np.asarray(seq_len, dtype=np.int64)
    return (batch_text_encoder_inputs, batch_speech_encoder_inputs,
            batch_decoder_inputs, batch_weights, seq_len)
# Scraped web-page residue (not code): "评论列表" (comment list), "文章目录" (table of contents).