def get_batch(self, train_data, bucket_id):
    """Sample a random batch from one bucket and lay it out time-major.

    ``self.batch_size`` (encoder, decoder) pairs are drawn with replacement
    from ``train_data[bucket_id]``. Each encoder sequence is padded with
    ``data_utils.PAD_ID`` up to the bucket's encoder size and then reversed;
    each decoder sequence is prefixed with ``data_utils.GO_ID`` and padded
    with ``data_utils.PAD_ID`` up to the bucket's decoder size.

    Args:
        train_data: indexable by ``bucket_id``; each entry is a non-empty
            sequence of ``(encoder_input, decoder_input)`` pairs, where both
            members are lists of token ids.
        bucket_id: index into ``self.buckets`` and ``train_data``.

    Returns:
        A 5-tuple ``(batch_encoder_inputs, batch_decoder_inputs,
        batch_weights, batch_source_encoder, batch_source_decoder)``:

        * ``batch_encoder_inputs``: list of ``encoder_size`` int32 arrays,
          one per time step, each holding one token per batch entry.
        * ``batch_decoder_inputs``: list of ``decoder_size`` int32 arrays,
          laid out the same way.
        * ``batch_weights``: list of ``decoder_size`` float32 arrays; an
          entry is 0.0 when the corresponding target (the decoder input
          shifted forward by one step) is PAD, and for every entry of the
          last step, otherwise 1.0.
        * ``batch_source_encoder`` / ``batch_source_decoder``: the sampled
          encoder / decoder lists, unpadded and in sampling order.

    Raises:
        IndexError: from ``random.choice`` when the selected bucket is empty.
    """
    encoder_size, decoder_size = self.buckets[bucket_id]
    pad_id = data_utils.PAD_ID
    go_id = data_utils.GO_ID

    encoder_inputs, decoder_inputs = [], []
    batch_source_encoder, batch_source_decoder = [], []

    # Draw the batch (sampling with replacement) and pad every sequence.
    # NOTE(review): a sequence longer than its bucket size gets no padding
    # and only its first `size` positions (after reversal, for the encoder)
    # are used below -- presumably callers bucket the data so this cannot
    # happen; confirm against the bucketing code.
    for _ in range(self.batch_size):
        encoder_input, decoder_input = random.choice(train_data[bucket_id])
        batch_source_encoder.append(encoder_input)
        batch_source_decoder.append(decoder_input)

        # Encoder inputs are padded and then reversed.
        encoder_pad = [pad_id] * (encoder_size - len(encoder_input))
        encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))

        # Decoder inputs get a leading GO symbol, hence the extra -1.
        decoder_pad_size = decoder_size - len(decoder_input) - 1
        decoder_inputs.append(
            [go_id] + decoder_input + [pad_id] * decoder_pad_size)

    # Re-index from batch-major rows to one array per time step.
    batch_encoder_inputs = [
        np.array([row[step] for row in encoder_inputs], dtype=np.int32)
        for step in range(encoder_size)
    ]
    batch_decoder_inputs = [
        np.array([row[step] for row in decoder_inputs], dtype=np.int32)
        for step in range(decoder_size)
    ]

    # The target at a step is the decoder input one step ahead, so the last
    # step has no target and is always weighted 0; PAD targets are too.
    last_step = decoder_size - 1
    batch_weights = []
    for step in range(decoder_size):
        if step == last_step:
            step_weights = np.zeros(self.batch_size, dtype=np.float32)
        else:
            step_weights = np.array(
                [0.0 if row[step + 1] == pad_id else 1.0
                 for row in decoder_inputs],
                dtype=np.float32)
        batch_weights.append(step_weights)

    return (batch_encoder_inputs, batch_decoder_inputs, batch_weights,
            batch_source_encoder, batch_source_decoder)
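# gst_rnn_model.py -- file source code (Python)
# Page metadata carried over from the web listing (not part of the program):
# views: 15 | favorites: 0 | likes: 0 | comments: 0
# [comment list] [table of contents]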