def generate_task_output(encoder_outputs, additional_inputs, encoder_state, targets,
                         sequence_length, num_decoder_symbols, weights,
                         buckets, softmax_loss_function=None,
                         per_example_loss=False, name=None, use_attention=False,
                         scope=None, DNN_at_output=False,
                         intent_results=None,
                         tagging_results=None,
                         train_with_true_label=True,
                         use_local_context=False,
                         forward_only=False,
                         regularizer_weight=1e-4):
    """Build the logits and the (regularized) loss for one task of the model.

    The task is selected by ``scope``:

    * ``'intent'``  -- ``intent_results`` is unpacked as
      ``(logits, regularizers, sampled_intents)``.
    * ``'tagging'`` -- ``tagging_results`` is unpacked as
      ``(logits, regularizers, sampled_tags)``.
    * ``'lm'``      -- logits and regularizers are produced here by
      ``generate_sequence_output`` from ``encoder_outputs`` (each element
      concatenated along dimension 1 with the matching element of
      ``additional_inputs`` when ``use_local_context`` is set).

    Args:
      encoder_outputs: list of tensors; used as the 'lm' task inputs and
        passed to the op scope.
      additional_inputs: list of tensors zipped with ``encoder_outputs``;
        only read for the 'lm' task when ``use_local_context`` is truthy.
      encoder_state: forwarded to ``generate_sequence_output`` ('lm' only).
      targets: list of label tensors; must contain one entry per logit and
        be at least ``buckets[-1][1]`` long.
      sequence_length: forwarded to ``generate_sequence_output`` ('lm' only).
      num_decoder_symbols: forwarded to ``generate_sequence_output``
        ('lm' only).
      weights: list of weight tensors passed to the sequence-loss helper.
      buckets: sequence of pairs; only ``buckets[-1][1]`` is read here, as a
        lower bound on ``len(targets)``.
      softmax_loss_function: optional override passed to the loss helper.
      per_example_loss: if truthy, use ``sequence_loss_by_example``;
        otherwise use ``sequence_loss``.
      name: optional name for the enclosing op scope.
      use_attention: forwarded to ``generate_sequence_output`` ('lm' only).
      scope: one of ``'intent'``, ``'tagging'``, ``'lm'``.
      DNN_at_output: forwarded to ``generate_sequence_output`` ('lm' only).
      intent_results: precomputed results, required when scope is 'intent'.
      tagging_results: precomputed results, required when scope is
        'tagging'.
      train_with_true_label: not used by this function; kept for interface
        compatibility.
      use_local_context: see ``additional_inputs``.
      forward_only: forwarded to ``generate_sequence_output`` ('lm' only).
      regularizer_weight: multiplier applied to ``regularizers`` before it
        is added to the cross-entropy (default 1e-4, the previously
        hard-coded value).

    Returns:
      Tuple ``(logits, sampled_tags, sampled_intents,
      crossent_with_regularizers, crossent)``. The sampled list that the
      selected task does not produce is returned empty.

    Raises:
      ValueError: if ``targets`` is shorter than ``buckets[-1][1]``, if
        ``scope`` is not recognised, if the precomputed results for the
        selected scope are missing, or if the number of logits differs from
        the number of targets.
    """
    if len(targets) < buckets[-1][1]:
        raise ValueError("Length of targets (%d) must be at least that of last "
                         "bucket (%d)." % (len(targets), buckets[-1][1]))
    # Validate the task selector before building any ops. Otherwise an
    # unknown scope falls through every branch and fails later on an unbound
    # `logits`.
    if scope not in ('intent', 'tagging', 'lm'):
        raise ValueError("Unknown task scope %r; expected 'intent', 'tagging' "
                         "or 'lm'." % (scope,))
    if scope == 'intent' and intent_results is None:
        raise ValueError("intent_results is required when scope == 'intent'.")
    if scope == 'tagging' and tagging_results is None:
        raise ValueError("tagging_results is required when scope == 'tagging'.")

    all_inputs = encoder_outputs + targets + weights
    with ops.op_scope(all_inputs, name, "model_with_buckets"):
        sampled_tags = []
        sampled_intents = []
        if scope == 'intent':
            logits, regularizers, sampled_intents = intent_results
        elif scope == 'tagging':
            logits, regularizers, sampled_tags = tagging_results
        else:  # scope == 'lm' (validated above)
            with variable_scope.variable_scope(
                    scope + "_generate_sequence_output", reuse=None):
                if use_local_context:
                    print('lm task: use sampled_tag_intent_emb as local context')
                    task_inputs = [
                        array_ops.concat(1, [additional_input, encoder_output])
                        for additional_input, encoder_output
                        in zip(additional_inputs, encoder_outputs)]
                else:
                    task_inputs = encoder_outputs
                logits, _, regularizers = generate_sequence_output(
                    task_inputs,
                    encoder_state,
                    num_decoder_symbols,
                    sequence_length,
                    use_attention=use_attention,
                    DNN_at_output=DNN_at_output,
                    forward_only=forward_only)

        if len(logits) != len(targets):
            raise ValueError("Number of logits (%d) must equal number of "
                             "targets (%d)." % (len(logits), len(targets)))
        # Make each target an int64 tensor and flatten it to rank 1 before
        # passing it to the loss helper.
        bucket_target = [array_ops.reshape(math_ops.to_int64(x), [-1])
                         for x in targets]
        # The previous code tested `per_example_loss is None`, which inverted
        # this boolean flag (default False): passing True selected the
        # averaged loss.
        loss_fn = sequence_loss_by_example if per_example_loss else sequence_loss
        crossent = loss_fn(logits, bucket_target, weights,
                           softmax_loss_function=softmax_loss_function)
        crossent_with_regularizers = crossent + regularizer_weight * regularizers
    return logits, sampled_tags, sampled_intents, crossent_with_regularizers, crossent
评论列表
文章目录