def lstm_seq2seq_internal(inputs, targets, hparams, train):
    """Build the plain LSTM encoder-decoder graph used for the training step.

    Everything is created under the ``"lstm_seq2seq"`` variable scope. When
    ``inputs`` is supplied it is flattened, reversed along axis 1 and run
    through the ``"encoder"`` LSTM; only the encoder's final state is kept
    and used to seed the ``"decoder"`` LSTM. When ``inputs`` is ``None`` the
    encoder is skipped and the decoder is called with ``initial_state=None``.

    Args:
        inputs: Source tensor, or ``None`` to run the decoder without an
            encoder. Presumably 4-D (it goes through ``flatten4d3d``) --
            TODO confirm against callers.
        targets: Target tensor. It is passed through
            ``common_layers.shift_right`` and then ``flatten4d3d`` before
            being fed to the decoder.
        hparams: Hyperparameters object, forwarded unchanged to ``lstm``.
        train: Training flag, forwarded unchanged to ``lstm``.

    Returns:
        The decoder LSTM outputs with an extra size-1 axis inserted at
        index 2.
    """
    with tf.variable_scope("lstm_seq2seq"):
        # No encoder state unless a source sequence is actually provided.
        encoder_state = None
        if inputs is not None:
            flat_inputs = common_layers.flatten4d3d(inputs)
            # The encoder consumes the source reversed along axis 1
            # (presumably the time axis -- confirm).
            reversed_inputs = tf.reverse(flat_inputs, axis=[1])
            _, encoder_state = lstm(
                reversed_inputs, hparams, train, "encoder")

        # Decoder input is the shifted targets, flattened the same way as
        # the encoder input.
        decoder_inputs = common_layers.flatten4d3d(
            common_layers.shift_right(targets))
        decoder_outputs, _ = lstm(
            decoder_inputs,
            hparams,
            train,
            "decoder",
            initial_state=encoder_state)

        # Re-insert the axis at position 2 on the way out.
        return tf.expand_dims(decoder_outputs, axis=2)
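# Scraped page navigation text ("Comment list"); not part of the module.
# Scraped page navigation text ("Table of contents"); not part of the module.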