def get_decode_batch(self, data, bucket_id):
    """Build one batch out of every sample of a bucket, in stored order.

    No random sampling takes place: each ``(encoder_input, decoder_input)``
    pair found in ``data[bucket_id]`` is used exactly once, so the batch
    size equals ``len(data[bucket_id])``.

    Args:
        data: container indexable by ``bucket_id``; each entry is a sequence
            of ``(encoder_input, decoder_input)`` pairs. ``decoder_input``
            has to be a list, because it is concatenated with lists.
        bucket_id: index used for both ``data`` and ``self.buckets``;
            ``self.buckets[bucket_id]`` yields
            ``(encoder_size, decoder_size)``.

    Returns:
        A tuple ``(batch_encoder_inputs, batch_decoder_inputs,
        batch_weights, seq_len)`` where

        * ``batch_encoder_inputs`` is a list of ``encoder_size`` int32
          arrays, one per time step, each holding that step's token for
          every sample;
        * ``batch_decoder_inputs`` is the same layout with ``decoder_size``
          int32 arrays;
        * ``batch_weights`` is a list of ``decoder_size`` float32 arrays
          that are 0.0 where the target is padding and 1.0 elsewhere;
        * ``seq_len`` is an int64 array with the unpadded length of every
          encoder input.

    Note:
        A sample longer than its bucket receives no padding (a negative
        repeat count produces an empty list) and only its first
        ``encoder_size`` / ``decoder_size`` positions are read below.
    """
    encoder_size, decoder_size = self.buckets[bucket_id]
    samples = data[bucket_id]
    this_batch_size = len(samples)
    pad_id = data_utils.PAD_ID
    go_id = data_utils.GO_ID

    seq_len = []
    encoder_inputs, decoder_inputs = [], []
    for encoder_input, decoder_input in samples:
        # Remember the true (unpadded) length of the encoder sequence.
        seq_len.append(len(encoder_input))

        # Only the real tokens are reversed; padding is appended afterwards
        # so that it stays at the end of the row.
        encoder_pad = [pad_id] * (encoder_size - len(encoder_input))
        encoder_inputs.append(list(reversed(encoder_input)) + encoder_pad)

        # Decoder rows start with GO, which takes one slot of the bucket.
        decoder_pad_size = decoder_size - len(decoder_input) - 1
        decoder_inputs.append(
            [go_id] + decoder_input + [pad_id] * decoder_pad_size)

    # Re-index the per-sample rows into one array per time step.
    # ``range`` is used instead of the Python-2-only ``xrange``.
    batch_encoder_inputs = [
        np.array([row[step] for row in encoder_inputs], dtype=np.int32)
        for step in range(encoder_size)
    ]
    batch_decoder_inputs = [
        np.array([row[step] for row in decoder_inputs], dtype=np.int32)
        for step in range(decoder_size)
    ]

    # The target of a step is the decoder input shifted forward by one.
    # Its weight is 0 when that target is PAD; the last step has no target
    # at all, so its weight is always 0.
    last_step = decoder_size - 1
    batch_weights = []
    for step in range(decoder_size):
        if step == last_step:
            batch_weight = np.zeros(this_batch_size, dtype=np.float32)
        else:
            batch_weight = np.array(
                [row[step + 1] != pad_id for row in decoder_inputs],
                dtype=np.float32)
        batch_weights.append(batch_weight)

    # Sequence lengths are handed back as an array alongside the batch.
    seq_len = np.asarray(seq_len, dtype=np.int64)
    return batch_encoder_inputs, batch_decoder_inputs, batch_weights, seq_len
评论列表
文章目录