seq2seq_model.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:seq2seq_parser 作者: trangham283 项目源码 文件源码
def get_decode_batch(self, data, bucket_id):
    """Build one batch from every sample of a bucket, in their stored order.

    Unlike a random-sampling batch builder, this walks ``data[bucket_id]``
    sequentially and uses all of it, so the batch size equals the number of
    samples in that bucket.

    Args:
      data: Indexable by ``bucket_id``; each entry is a sequence of
        ``(encoder_input, decoder_input)`` pairs.  ``decoder_input`` must be a
        list (it is concatenated with lists); both presumably hold token ids.
      bucket_id: Index into both ``data`` and ``self.buckets``; the bucket
        gives ``(encoder_size, decoder_size)``.

    Returns:
      A 4-tuple ``(encoder_inputs, decoder_inputs, target_weights, seq_len)``:
        encoder_inputs: list of ``encoder_size`` int32 arrays, one per
          position, each holding that position for every sample.
        decoder_inputs: list of ``decoder_size`` int32 arrays, laid out the
          same way; every decoder row starts with ``data_utils.GO_ID``.
        target_weights: list of ``decoder_size`` float32 arrays; an entry is
          0.0 when the target for that step (the decoder token one position
          later) is ``data_utils.PAD_ID``, and every entry of the last step is
          0.0 because it has no following token.  All other entries are 1.0.
        seq_len: int64 array of the unpadded encoder input lengths.

    NOTE(review): inputs longer than the bucket are not rejected.  An
    over-long encoder input is cut to its last ``encoder_size`` tokens
    (after reversal only the first ``encoder_size`` positions are read), and
    its ``seq_len`` entry still reports the full length; an over-long decoder
    input is cut to its first ``decoder_size`` positions.
    """
    encoder_size, decoder_size = self.buckets[bucket_id]
    samples = data[bucket_id]
    batch_size = len(samples)
    pad_id, go_id = data_utils.PAD_ID, data_utils.GO_ID

    # Per-sample padded rows plus the true encoder lengths.
    encoder_rows, decoder_rows, seq_len = [], [], []
    for encoder_input, decoder_input in samples:
      seq_len.append(len(encoder_input))
      # Only the real tokens are reversed; padding is appended on the right,
      # so the real (reversed) tokens occupy the leading positions.  This is
      # presumably what lets seq_len act as an RNN sequence length — confirm
      # against the model that consumes it.
      encoder_rows.append(list(reversed(encoder_input)) +
                          [pad_id] * (encoder_size - len(encoder_input)))
      # Decoder rows get a leading GO symbol, then right padding.
      decoder_rows.append([go_id] + decoder_input +
                          [pad_id] * (decoder_size - len(decoder_input) - 1))

    # Regroup by position: one array per step across the whole batch.
    # Every row is at least as long as its bucket size, so indexing is safe.
    batch_encoder_inputs = [
        np.array([row[step] for row in encoder_rows], dtype=np.int32)
        for step in range(encoder_size)]

    batch_decoder_inputs, batch_weights = [], []
    for step in range(decoder_size):
      batch_decoder_inputs.append(
          np.array([row[step] for row in decoder_rows], dtype=np.int32))
      if step < decoder_size - 1:
        # The target at this step is the decoder token one position later;
        # padded targets must not contribute to the loss.
        weight = np.array(
            [0.0 if row[step + 1] == pad_id else 1.0 for row in decoder_rows],
            dtype=np.float32)
      else:
        # The final step has no following token, hence no target.
        weight = np.zeros(batch_size, dtype=np.float32)
      batch_weights.append(weight)

    seq_len = np.asarray(seq_len, dtype=np.int64)
    return batch_encoder_inputs, batch_decoder_inputs, batch_weights, seq_len
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号