gst_rnn_model.py source file

python

Project: Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow  Author: liuyuemaicha
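    def get_batch(self, train_data, bucket_id):
        """Sample a random batch from one bucket and lay it out time-major.

        Returns:
            batch_encoder_inputs: list of encoder_size int32 arrays, each of
                shape (batch_size,), holding the padded and reversed sources.
            batch_decoder_inputs: list of decoder_size int32 arrays, each of
                shape (batch_size,), holding GO + target + padding.
            batch_weights: list of decoder_size float32 arrays, 0.0 wherever
                the next-step target is PAD or the step is the last one.
            batch_source_encoder, batch_source_decoder: the sampled,
                unpadded token-id lists.
        """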
        encoder_size, decoder_size = self.buckets[bucket_id]
        encoder_inputs, decoder_inputs = [], []
        batch_source_encoder, batch_source_decoder = [], []

        #print("bucket_id: ", bucket_id)
        for batch_i in xrange(self.batch_size):
            encoder_input, decoder_input = random.choice(train_data[bucket_id])

            batch_source_encoder.append(encoder_input)
            batch_source_decoder.append(decoder_input)

            #print("encoder_input: ", encoder_input)
            encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
            encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
            #print("encoder_input pad: ", list(reversed(encoder_input + encoder_pad)))

            #print("decoder_input: ", decoder_input)
            decoder_pad_size = decoder_size - len(decoder_input) - 1
            decoder_inputs.append([data_utils.GO_ID] + decoder_input +
                                  [data_utils.PAD_ID] * decoder_pad_size)
            #print("decoder_pad: ",[data_utils.GO_ID] + decoder_input + [data_utils.PAD_ID] * decoder_pad_size)

        batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []

        for length_idx in xrange(encoder_size):
            batch_encoder_inputs.append(
                np.array([encoder_inputs[batch_idx][length_idx]
                          for batch_idx in xrange(self.batch_size)], dtype=np.int32))

        for length_idx in xrange(decoder_size):
            batch_decoder_inputs.append(
                np.array([decoder_inputs[batch_idx][length_idx]
                          for batch_idx in xrange(self.batch_size)], dtype=np.int32))

            batch_weight = np.ones(self.batch_size, dtype=np.float32)
            for batch_idx in xrange(self.batch_size):
                # We set weight to 0 if the corresponding target is a PAD symbol.
                # The corresponding target is decoder_input shifted by 1 forward.
                if length_idx < decoder_size - 1:
                    target = decoder_inputs[batch_idx][length_idx + 1]
                if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:
                    batch_weight[batch_idx] = 0.0
            batch_weights.append(batch_weight)

        return batch_encoder_inputs, batch_decoder_inputs, batch_weights, batch_source_encoder, batch_source_decoder
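
The padding, reversal, time-major layout, and target weights computed by get_batch can be tried on a toy batch outside the model class. The following is a minimal sketch, not the project's code: make_batch, the PAD_ID and GO_ID values, and the sample pairs are assumptions made for illustration, and it returns 2-D arrays of shape (time, batch) rather than lists of 1-D arrays. It also takes the pairs directly instead of sampling them with random.choice.

```python
import numpy as np

# Assumed token ids for illustration; the real values live in data_utils.
PAD_ID = 0
GO_ID = 1


def make_batch(pairs, encoder_size, decoder_size):
    """Pad (source, target) id lists; return time-major arrays and weights."""
    encoder_rows, decoder_rows = [], []
    for source, target in pairs:
        # Pad the source to encoder_size, then reverse it.
        padded = source + [PAD_ID] * (encoder_size - len(source))
        encoder_rows.append(padded[::-1])
        # Prefix GO, then pad the target to decoder_size.
        pad_len = decoder_size - len(target) - 1
        decoder_rows.append([GO_ID] + target + [PAD_ID] * pad_len)

    # Transpose from (batch, time) to (time, batch).
    encoder_tm = np.array(encoder_rows, dtype=np.int32).T
    decoder_tm = np.array(decoder_rows, dtype=np.int32).T

    # The target at step t is the decoder input at step t + 1.
    # The last step has no target, so it stays PAD and gets weight 0.
    targets = np.full_like(decoder_tm, PAD_ID)
    targets[:-1] = decoder_tm[1:]
    weights = (targets != PAD_ID).astype(np.float32)
    return encoder_tm, decoder_tm, weights


pairs = [([4, 5], [6, 7, 8]), ([9, 10, 11], [12])]
enc, dec, w = make_batch(pairs, encoder_size=3, decoder_size=5)
print(enc)  # one row per encoder time step
print(w)    # 1.0 where the next decoder token is a real target
```

Iterating over the rows of enc, dec, and w gives the same per-time-step arrays that get_batch builds with its explicit loops.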