def create_batches(data, buckets=None, batch_size=None, pad_id=None, go_id=None):
    """Build shuffled, time-major training batches for every bucket.

    For each bucket the examples are shuffled once and cut into
    ``ceil(len(bucket) / batch_size)`` batches. The last batch wraps around to
    the start of the shuffled order, so every batch holds exactly
    ``batch_size`` examples. Encoder inputs are padded to the bucket's encoder
    size and then reversed. Decoder inputs get a leading GO symbol and are then
    padded to the bucket's decoder size.

    Args:
        data: sequence indexed by bucket id. ``data[i]`` is a list of
            ``(encoder_input, decoder_input)`` pairs of token-id lists.
        buckets: list of ``(encoder_size, decoder_size)`` pairs. Defaults to
            the module-level ``_buckets``.
        batch_size: examples per batch. Defaults to ``FLAGS.batch_size``.
        pad_id: padding token id. Defaults to ``data_utils.PAD_ID``.
        go_id: decoder start token id. Defaults to ``data_utils.GO_ID``.

    Returns:
        A list with one entry per bucket. Each entry is a list of
        ``(encoder_inputs, decoder_inputs, target_weights)`` tuples:
        - ``encoder_inputs``: ``encoder_size`` int32 arrays of shape ``(batch_size,)``.
        - ``decoder_inputs``: ``decoder_size`` int32 arrays of the same shape.
        - ``target_weights``: ``decoder_size`` float32 arrays of the same shape.
        A weight is 0 where the target (the decoder input one step ahead) is
        padding, and it is always 0 at the final step. An empty bucket
        yields no batches.
    """
    if buckets is None:
        buckets = _buckets
    if batch_size is None:
        batch_size = FLAGS.batch_size
    if pad_id is None:
        pad_id = data_utils.PAD_ID
    if go_id is None:
        go_id = data_utils.GO_ID

    print("generating batches...")
    batches = [[] for _ in buckets]
    for bucket_id, (encoder_size, decoder_size) in enumerate(buckets):
        data_bucket = data[bucket_id]
        num_examples = len(data_bucket)
        # Shuffle the bucket once; every batch draws from this order.
        data_permute = np.random.permutation(num_examples)
        # Integer ceiling division. math.ceil(a / b) floors first under Python 2
        # integer division, which would drop the partial tail batch (and entire
        # buckets smaller than batch_size).
        num_batches = (num_examples + batch_size - 1) // batch_size
        for b_idx in range(num_batches):
            # Wrap around the shuffled order so the final batch is full.
            positions = (np.arange(b_idx * batch_size, (b_idx + 1) * batch_size)
                         % num_examples)
            examples = [data_bucket[i] for i in data_permute[positions]]
            batches[bucket_id].append(
                _make_time_major_batch(examples, encoder_size, decoder_size,
                                       pad_id, go_id))
    return batches


def _make_time_major_batch(examples, encoder_size, decoder_size, pad_id, go_id):
    """Pad one batch of (encoder, decoder) pairs and transpose it to time-major arrays."""
    encoder_rows, decoder_rows = [], []
    for encoder_input, decoder_input in examples:
        # Encoder inputs are padded and then reversed. Only the first
        # encoder_size positions of the reversed sequence are used.
        encoder_pad = [pad_id] * (encoder_size - len(encoder_input))
        encoder_rows.append(list(reversed(encoder_input + encoder_pad))[:encoder_size])
        # Decoder inputs get a leading GO symbol and are then padded.
        decoder_pad = [pad_id] * (decoder_size - len(decoder_input) - 1)
        decoder_rows.append(([go_id] + decoder_input + decoder_pad)[:decoder_size])
    encoder = np.array(encoder_rows, dtype=np.int32)  # (batch, encoder_size)
    decoder = np.array(decoder_rows, dtype=np.int32)  # (batch, decoder_size)
    # The target at step t is the decoder input at step t + 1. Its weight is 0
    # when that target is padding, and always 0 at the final step (no target).
    weights = np.zeros((decoder_size, len(examples)), dtype=np.float32)
    weights[:-1] = (decoder[:, 1:] != pad_id).T
    return (list(np.ascontiguousarray(encoder.T)),
            list(np.ascontiguousarray(decoder.T)),
            list(weights))
#-----------------------------------------------------
# main training function
#-----------------------------------------------------
评论列表
文章目录