many2one_seq2seq.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:seq2seq_parser 作者: trangham283 项目源码 文件源码
def many2one_model_with_buckets(encoder_inputs_list, decoder_inputs, targets, weights,
                       buckets, seq2seq, softmax_loss_function=None,
                       per_example_loss=False, name=None, spscale=20):
  """Build one seq2seq sub-graph per bucket for a model with two encoders.

  Variant of the bucketed seq2seq model builder that feeds two encoder
  input sequences to `seq2seq` instead of one.

  Args:
    encoder_inputs_list: list whose first two items are the encoder input
      sequences. Item 0 is truncated to `bucket[0]` entries and item 1 to
      `bucket[0] * spscale` entries for each bucket.
    decoder_inputs: sequence of decoder inputs, truncated to `bucket[1]`.
    targets: sequence of targets, truncated to `bucket[1]`.
    weights: sequence of loss weights, truncated to `bucket[1]`.
    buckets: list of `(encoder_size, decoder_size)` pairs; the last pair is
      the one the length checks are made against.
    seq2seq: callable invoked as `seq2seq([enc0, enc1], dec)`; it must
      return a pair whose first element is the bucket's outputs.
    softmax_loss_function: forwarded unchanged to the loss function.
    per_example_loss: if true, use `sequence_loss_by_example`; otherwise
      use `sequence_loss`.
    name: optional name passed to the enclosing op scope.
    spscale: multiplier applied to each bucket's encoder size to get the
      length used for the second encoder input.

  Returns:
    A pair `(outputs, losses)`: two lists with one entry per bucket, in
    bucket order.

  Raises:
    ValueError: if the first encoder input sequence, `targets` or `weights`
      is shorter than the corresponding size of the last bucket.
  """
  # The original message referenced an undefined name (`encoder_inputs`), so
  # this check raised NameError instead of the intended ValueError.
  if len(encoder_inputs_list[0]) < buckets[-1][0]:
    raise ValueError("Length of encoder_inputs (%d) must be at least that of "
                     "last bucket (%d)."
                     % (len(encoder_inputs_list[0]), buckets[-1][0]))
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last "
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  if len(weights) < buckets[-1][1]:
    raise ValueError("Length of weights (%d) must be at least that of last "
                     "bucket (%d)." % (len(weights), buckets[-1][1]))

  # NOTE(review): encoder_inputs_list is a list of sequences, so the encoder
  # inputs end up nested (not flattened) in all_inputs -- confirm op_scope
  # handles that as intended.
  all_inputs = encoder_inputs_list + decoder_inputs + targets + weights
  losses = []
  outputs = []
  # Length of the second encoder input for each bucket.
  speech_buckets = [(x * spscale, y) for (x, y) in buckets]
  with ops.op_scope(all_inputs, name, "many2one_model_with_buckets"):
    for j, bucket in enumerate(buckets):
      # Buckets after the first one reuse the variables created by the first.
      with variable_scope.variable_scope(variable_scope.get_variable_scope(),
                                         reuse=True if j > 0 else None):
        x = encoder_inputs_list[0][:bucket[0]]
        y = encoder_inputs_list[1][:speech_buckets[j][0]]
        bucket_outputs, _ = seq2seq([x, y], decoder_inputs[:bucket[1]])
        outputs.append(bucket_outputs)
        loss_fn = sequence_loss_by_example if per_example_loss else sequence_loss
        losses.append(loss_fn(
            outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
            softmax_loss_function=softmax_loss_function))

  return outputs, losses
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号