def get_batch(self, bucket_dbs, bucket_id, data):
    """Build one batch for bucket ``bucket_id`` from ``data``.

    Each ``(encoder_input, decoder_input)`` pair is converted with
    ``data_utils.sentence_indice`` (presumably into lists of token ids).

    * Encoder rows are right-padded with ``PAD_ID`` up to the bucket's
      encoder length and then reversed.
    * Decoder rows are ``[GO_ID] + ids + [EOS_ID]``, right-padded with
      ``PAD_ID`` up to the bucket's decoder length.

    The rows are then transposed so the result holds one array per position,
    each with ``self.batch_size`` entries.

    Args:
        bucket_dbs: Not used by this implementation; kept for interface
            compatibility.
        bucket_id: Index into ``self.buckets``. The entry is unpacked as
            ``(encoder_size, decoder_size)``.
        data: Iterable of ``(encoder_input, decoder_input)`` pairs. Every
            pair is converted, but only the first ``self.batch_size`` rows
            are placed in the batch. With fewer pairs than that, indexing a
            missing row raises ``IndexError``.

    Returns:
        A tuple ``(batch_encoder_inputs, batch_decoder_inputs,
        batch_weights)``:

        * ``batch_encoder_inputs``: ``encoder_size`` arrays of ``np.int32``.
        * ``batch_decoder_inputs``: ``decoder_size`` arrays of ``np.int32``.
        * ``batch_weights``: ``decoder_size`` arrays of ``np.float32``.

        A weight is 0.0 at the final decoder position and wherever the next
        decoder token is ``PAD_ID``. Otherwise it is 1.0.

    NOTE(review): when a sentence is longer than its bucket allows, the pad
    count is negative, so no padding is added and the row is longer than
    the bucket. Only the first ``encoder_size`` / ``decoder_size`` positions
    are read. As a result the encoder keeps only the last ``encoder_size``
    source tokens (reversed), and the decoder row can lose its ``EOS_ID``.
    Confirm that callers only pass pairs that fit the bucket.
    """
    encoder_size, decoder_size = self.buckets[bucket_id]
    batch_size = self.batch_size
    pad_id = data_utils.PAD_ID
    go_id = data_utils.GO_ID
    eos_id = data_utils.EOS_ID

    encoder_rows = []
    decoder_rows = []
    for source, target in data:
        source_ids = data_utils.sentence_indice(source)
        target_ids = data_utils.sentence_indice(target)
        # A negative repeat count yields an empty pad (see NOTE above).
        encoder_rows.append(
            (source_ids + [pad_id] * (encoder_size - len(source_ids)))[::-1]
        )
        # "- 2" reserves room for the GO and EOS markers.
        decoder_rows.append(
            [go_id] + target_ids + [eos_id]
            + [pad_id] * (decoder_size - len(target_ids) - 2)
        )

    def to_column(rows, position):
        # Gather column `position` across the first `batch_size` rows.
        return np.array(
            [rows[k][position] for k in range(batch_size)],
            dtype=np.int32,
        )

    batch_encoder_inputs = [
        to_column(encoder_rows, position) for position in range(encoder_size)
    ]

    batch_decoder_inputs = []
    batch_weights = []
    final_position = decoder_size - 1
    for position in range(decoder_size):
        batch_decoder_inputs.append(to_column(decoder_rows, position))
        if position == final_position:
            # Nothing follows the final position, so every weight is zero.
            weights = np.zeros(batch_size, dtype=np.float32)
        else:
            # Zero out positions whose following decoder token is padding
            # (presumably the target for this position).
            weights = np.ones(batch_size, dtype=np.float32)
            for k in range(batch_size):
                if decoder_rows[k][position + 1] == pad_id:
                    weights[k] = 0.0
        batch_weights.append(weights)

    return batch_encoder_inputs, batch_decoder_inputs, batch_weights
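# Stray web-page text captured with this code (not part of the program):
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").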