parse_s2s_att.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:parse_seq2seq 作者: avikdelta 项目源码 文件源码
def create_batches(data):
    """Pre-build all training batches for every bucket.

    For each bucket, the examples are shuffled once with
    ``np.random.permutation``. They are then cut into
    ``ceil(len(bucket) / FLAGS.batch_size)`` batches. The final batch of a
    bucket is topped up by wrapping around to the start of the shuffled order,
    so every batch holds exactly ``FLAGS.batch_size`` examples.

    Args:
      data: sequence indexed by bucket id. ``data[bucket_id]`` is a sequence
        of ``(encoder_input, decoder_input)`` pairs, each a list of token ids.
        NOTE(review): each pair is assumed to fit its bucket
        (``len(encoder_input) <= encoder_size`` and
        ``len(decoder_input) <= decoder_size - 1``). Over-long inputs are not
        rejected here; confirm the bucketing step guarantees this.

    Returns:
      A list with one entry per bucket in ``_buckets``. Each entry is a list of
      ``(batch_encoder_inputs, batch_decoder_inputs, batch_weights)`` tuples.
      Each of the three is a list with one array per time step:
      - ``batch_encoder_inputs`` has ``encoder_size`` int32 arrays.
      - ``batch_decoder_inputs`` and ``batch_weights`` each have
        ``decoder_size`` arrays (int32 and float32 respectively).
      Every array has length ``FLAGS.batch_size``.
    """
    print("generating batches...")
    batch_size = FLAGS.batch_size
    batches = [[] for _ in _buckets]
    for bucket_id in xrange(len(_buckets)):
        data_bucket = data[bucket_id]
        encoder_size, decoder_size = _buckets[bucket_id]
        bucket_len = len(data_bucket)

        # Shuffle the bucket once; batches are read off this permutation.
        data_permute = np.random.permutation(bucket_len)

        # Integer ceiling division. Plain `/` floor-divides ints under
        # Python 2, which would make a math.ceil a no-op and silently drop
        # the trailing partial batch (and math.ceil returns a float there).
        num_batches = (bucket_len + batch_size - 1) // batch_size
        for b_idx in xrange(num_batches):
            encoder_inputs, decoder_inputs = [], []
            for i in xrange(batch_size):
                # Wrap around so the last, partial batch is filled with
                # examples from the start of the permutation.
                data_idx = data_permute[(b_idx * batch_size + i) % bucket_len]
                encoder_input, decoder_input = data_bucket[data_idx]

                # Encoder inputs are padded to encoder_size and then reversed.
                encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
                encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))

                # Decoder inputs get a leading GO symbol and are then padded
                # up to decoder_size.
                decoder_pad_size = decoder_size - len(decoder_input) - 1
                decoder_inputs.append([data_utils.GO_ID] + decoder_input +
                                      [data_utils.PAD_ID] * decoder_pad_size)

            # Re-index batch-major lists into time-major arrays: element t
            # holds token t of every example in the batch.
            batch_encoder_inputs = [
                np.array([encoder_inputs[b][t] for b in xrange(batch_size)],
                         dtype=np.int32)
                for t in xrange(encoder_size)]
            batch_decoder_inputs = [
                np.array([decoder_inputs[b][t] for b in xrange(batch_size)],
                         dtype=np.int32)
                for t in xrange(decoder_size)]

            # The target at step t is the decoder input at step t + 1. Its
            # weight is 0 when that target is PAD, and at the last step, which
            # has no target. The `or` short-circuits before indexing t + 1
            # out of range.
            batch_weights = []
            for t in xrange(decoder_size):
                batch_weight = np.ones(batch_size, dtype=np.float32)
                for b in xrange(batch_size):
                    if (t == decoder_size - 1 or
                            decoder_inputs[b][t + 1] == data_utils.PAD_ID):
                        batch_weight[b] = 0.0
                batch_weights.append(batch_weight)

            batches[bucket_id].append((batch_encoder_inputs, batch_decoder_inputs, batch_weights))

    return batches

#-----------------------------------------------------
# main training function
#-----------------------------------------------------
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号