model.py source code

python

Project: Video-Captioning    Author: hehefan
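import random

import numpy as np

import data_utils  # project module that defines GO_ID, EOS_ID and PAD_ID


def get_batch(self, features, sentences, lengths):
    """Build one time-major batch for the encoder-decoder model.

    features: per-video lists of frame feature vectors, indexed by video id
    sentences: list of (video id, list of token ids) pairs
    lengths: number of frames per video, indexed by video id
    """
    batch_size = len(sentences)
    encoder_inputs, encoder_lengths, decoder_inputs = [], [], []
    feature_pad = np.array([0.0] * self.feature_size)
    for (vid, sen) in sentences:
        feature = features[vid]
        # Clamp so the reported length never exceeds the padded encoder length.
        encoder_lengths.append(min(lengths[vid], self.encoder_max_sequence_length))
        if len(feature) > self.encoder_max_sequence_length:
            # Subsample frames down to the maximum length. Note that
            # random.sample does not preserve the original frame order.
            feature = random.sample(feature, self.encoder_max_sequence_length)
        # Zero-pad the frame features up to the fixed encoder length.
        pad_size = self.encoder_max_sequence_length - len(feature)
        encoder_inputs.append(feature + [feature_pad] * pad_size)

        # Decoder input: GO, sentence tokens, EOS, then PAD up to the fixed length.
        pad_size = self.decoder_max_sentence_length - len(sen) - 2
        decoder_inputs.append([data_utils.GO_ID] + sen + [data_utils.EOS_ID] + [data_utils.PAD_ID] * pad_size)

    # Re-index from batch-major to time-major: one array per time step.
    batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
    for length_idx in range(self.encoder_max_sequence_length):
        batch_encoder_inputs.append(np.array([encoder_inputs[batch_idx][length_idx] for batch_idx in range(batch_size)], dtype=np.float32))
    batch_encoder_lengths = np.array(encoder_lengths)
    for length_idx in range(self.decoder_max_sentence_length):
        batch_decoder_inputs.append(np.array([decoder_inputs[batch_idx][length_idx] for batch_idx in range(batch_size)], dtype=np.int32))
        # Create target_weights to be 0 for targets that are padding.
        # The target at step t is the decoder input at step t + 1, so the
        # last step never has a target and always gets weight 0.
        batch_weight = np.ones(batch_size, dtype=np.float32)
        for batch_idx in range(batch_size):
            if length_idx < self.decoder_max_sentence_length - 1:
                target = decoder_inputs[batch_idx][length_idx + 1]
            if length_idx == self.decoder_max_sentence_length - 1 or target == data_utils.PAD_ID:
                batch_weight[batch_idx] = 0.0
        batch_weights.append(batch_weight)
    return batch_encoder_inputs, batch_encoder_lengths, batch_decoder_inputs, batch_weights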
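The decoder half of get_batch (GO/EOS/PAD framing, the batch-major to time-major re-indexing, and the "next token is not padding" weight rule) can also be written with vectorized NumPy. This is a minimal sketch under stated assumptions, not the project's code: the function name decoder_batch is hypothetical, and the token IDs PAD_ID = 0, GO_ID = 1, EOS_ID = 2 are placeholder values chosen for illustration rather than the actual constants in the project's data_utils module.

```python
import numpy as np

# Placeholder special-token IDs (assumption); the original reads GO_ID,
# EOS_ID and PAD_ID from the project's data_utils module.
PAD_ID, GO_ID, EOS_ID = 0, 1, 2


def decoder_batch(sentences, max_len):
    """Hypothetical vectorized version of the decoder side of get_batch.

    sentences: list of token-ID lists (without GO/EOS).
    Returns (inputs, weights): two lists of max_len arrays, each of shape
    (batch,), i.e. one array per decoder time step.
    """
    # Batch-major matrix pre-filled with PAD, then GO + tokens + EOS per row.
    rows = np.full((len(sentences), max_len), PAD_ID, dtype=np.int32)
    for i, sen in enumerate(sentences):
        seq = [GO_ID] + list(sen) + [EOS_ID]
        if len(seq) > max_len:
            raise ValueError("sentence %d needs %d slots, max_len is %d"
                             % (i, len(seq), max_len))
        rows[i, :len(seq)] = seq

    # The target for input step t is the token at step t + 1, so weight t is
    # 1.0 only when that next token exists and is not padding. The last
    # column stays 0.0 because it has no next token.
    weights = np.zeros_like(rows, dtype=np.float32)
    weights[:, :-1] = rows[:, 1:] != PAD_ID

    # Transposing turns (batch, time) into (time, batch); iterating the
    # result yields one per-time-step array, matching get_batch's layout.
    return list(rows.T), list(weights.T)


inputs, weights = decoder_batch([[5, 6], [7]], max_len=5)
print(inputs[2], weights[2])
```

Here inputs[2] is [6, 2] and weights[2] is [1.0, 0.0], because the second sentence's target at step 2 is padding. One deliberate difference from the original: a sentence that does not fit raises ValueError, whereas get_batch would silently drop its trailing tokens (including EOS) when it reads only the first decoder_max_sentence_length positions.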