seq_labeling.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:joint-slu-lm 作者: HadoopIt 项目源码 文件源码
def generate_task_output(encoder_outputs, additional_inputs, encoder_state, targets,sequence_length, num_decoder_symbols, weights,
                       buckets, softmax_loss_function=None,
                       per_example_loss=False, name=None, use_attention=False, scope=None, DNN_at_output=False,
                       intent_results=None,
                       tagging_results=None,
                       train_with_true_label=True,
                       use_local_context=False,
                       forward_only=False):
  """Build the output logits and cross-entropy loss for one task.

  For scope 'intent' or 'tagging', the logits, regularizers and sampled
  outputs are unpacked from the precomputed `intent_results` /
  `tagging_results` tuples. For scope 'lm', the logits are produced here by
  `generate_sequence_output`. When `use_local_context` is set, each
  `additional_inputs` entry is first concatenated onto the matching encoder
  output.

  Args:
    encoder_outputs: list of per-step encoder outputs.
    additional_inputs: list of per-step extra inputs. Only read when
      scope == 'lm' and use_local_context is True.
    encoder_state: forwarded to generate_sequence_output ('lm' only).
    targets: list of per-step target tensors; each is cast to int64 and
      flattened before the loss is computed.
    sequence_length: forwarded to generate_sequence_output ('lm' only).
    num_decoder_symbols: forwarded to generate_sequence_output ('lm' only).
    weights: list of per-step target weights passed to the loss op.
    buckets: list of bucket pairs. Only buckets[-1][1] is read, as the
      minimum allowed len(targets).
    softmax_loss_function: optional function forwarded to the loss op.
    per_example_loss: if truthy, the loss comes from
      sequence_loss_by_example (presumably one value per example). Otherwise
      it comes from sequence_loss.
    name: optional name for the enclosing op scope.
    use_attention, DNN_at_output, forward_only: forwarded to
      generate_sequence_output ('lm' only).
    scope: which task to build; must be 'intent', 'tagging' or 'lm'.
    intent_results: (logits, regularizers, sampled_intents), used when
      scope == 'intent'.
    tagging_results: (logits, regularizers, sampled_tags), used when
      scope == 'tagging'.
    train_with_true_label: not used by this function.
    use_local_context: see `additional_inputs`.

  Returns:
    (logits, sampled_tags, sampled_intents, crossent_with_regularizers,
    crossent), where crossent_with_regularizers = crossent + 1e-4 * regularizers.
    Sampled lists that do not apply to the chosen scope are empty.

  Raises:
    ValueError: if len(targets) is smaller than buckets[-1][1], or if scope
      is not one of 'intent', 'tagging', 'lm'.
  """
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last "
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  # Validate up front: an unknown scope (including the default None) would
  # otherwise leave `logits` unbound and fail later with a NameError.
  if scope not in ('intent', 'tagging', 'lm'):
    raise ValueError("Unknown scope %r; expected 'intent', 'tagging' or 'lm'."
                     % (scope,))

  all_inputs = encoder_outputs + targets + weights
  with ops.op_scope(all_inputs, name, "model_with_buckets"):
    sampled_tags = list()
    sampled_intents = list()
    if scope == 'intent':
      logits, regularizers, sampled_intents = intent_results
    elif scope == 'tagging':
      logits, regularizers, sampled_tags = tagging_results
    else:  # scope == 'lm'
      with variable_scope.variable_scope(scope + "_generate_sequence_output", reuse=None):
        if use_local_context:
          print('lm task: use sampled_tag_intent_emb as local context')
          # Concatenate along axis 1 (old TF concat(axis, values) signature).
          task_inputs = [array_ops.concat(1, [additional_input, encoder_output])
                         for additional_input, encoder_output
                         in zip(additional_inputs, encoder_outputs)]
        else:
          task_inputs = encoder_outputs

        logits, _, regularizers = generate_sequence_output(task_inputs,
                                                           encoder_state,
                                                           num_decoder_symbols,
                                                           sequence_length,
                                                           use_attention=use_attention,
                                                           DNN_at_output=DNN_at_output,
                                                           forward_only=forward_only)

    assert len(logits) == len(targets)
    # We need to make each target an int64 tensor and set its (flat) shape.
    bucket_target = [array_ops.reshape(math_ops.to_int64(x), [-1]) for x in targets]
    # Test truthiness, not `is None`. The previous check sent both the default
    # False and an explicit True to sequence_loss. It also left
    # crossent_with_regularizers unbound on the per-example path.
    loss_fn = sequence_loss_by_example if per_example_loss else sequence_loss
    crossent = loss_fn(logits, bucket_target, weights,
                       softmax_loss_function=softmax_loss_function)
    crossent_with_regularizers = crossent + 1e-4 * regularizers

  return logits, sampled_tags, sampled_intents, crossent_with_regularizers, crossent
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号