def get_initial_loop_state(self) -> LoopState:
    """Build the loop state that the decoder starts from at step 0.

    The returned ``LoopState`` bundles three groups of values:

    * ``histories`` -- growable ``tf.TensorArray`` buffers for decoder
      outputs (float32, with ``self.initial_state`` already written at
      index 0), logits (float32), emitted symbols (int32) and the mask
      (bool), plus one ``initial_loop_state()`` per attention object.
    * ``constants`` -- a ``DecoderConstants`` wrapping ``self.train_inputs``.
    * ``feedables`` -- the values fed into the first step: step counter
      ``0``, an all-False ``finished`` vector of length ``self.batch_size``,
      ``self.go_symbols`` as the input symbol, zero logits of shape
      ``[self.batch_size, len(self.vocabulary)]``, ``self.initial_state``
      as both the previous RNN state and the previous RNN output, and one
      zero context vector per attention object.

    ``None`` entries in ``self.attentions`` are skipped for both the zero
    contexts and the attention loop states, so those two lists always have
    the same length and order.

    Returns:
        The initial ``LoopState``.
    """
    # Filter once and reuse, so that ``prev_contexts`` and
    # ``attention_histories`` stay index-aligned. Filtering only the loop
    # states would still dereference ``None.context_vector_size`` below.
    attentions = [att for att in self.attentions if att is not None]

    rnn_output_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True,
                                   size=0, name="decoder_outputs")
    # Index 0 of the decoder-output history holds the initial state.
    # Presumably the step function writes its outputs from index 1 on --
    # confirm against the loop body.
    rnn_output_ta = rnn_output_ta.write(0, self.initial_state)
    logit_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True,
                              size=0, name="logits")
    outputs_ta = tf.TensorArray(dtype=tf.int32, dynamic_size=True,
                                size=0, name="outputs")
    # One zero context per attention, shaped
    # [batch_size, context_vector_size].
    contexts = [tf.zeros([self.batch_size, att.context_vector_size])
                for att in attentions]
    mask_ta = tf.TensorArray(dtype=tf.bool, dynamic_size=True,
                             size=0, name="mask")
    attn_loop_states = [att.initial_loop_state() for att in attentions]

    # pylint: disable=not-callable
    rnn_feedables = RNNFeedables(
        # general:
        step=0,
        # tf.zeros with dtype=bool -> every sequence starts as not finished.
        finished=tf.zeros([self.batch_size], dtype=tf.bool),
        input_symbol=self.go_symbols,
        prev_logits=tf.zeros([self.batch_size, len(self.vocabulary)]),
        # rnn-specific:
        prev_rnn_state=self.initial_state,
        # No RNN output exists before step 0; the initial state stands in.
        prev_rnn_output=self.initial_state,
        prev_contexts=contexts)
    rnn_histories = RNNHistories(
        attention_histories=attn_loop_states,
        # general:
        logits=logit_ta,
        decoder_outputs=rnn_output_ta,
        outputs=outputs_ta,
        mask=mask_ta)
    # pylint: enable=not-callable

    loop_constants = DecoderConstants(train_inputs=self.train_inputs)
    return LoopState(
        histories=rnn_histories,
        constants=loop_constants,
        feedables=rnn_feedables)
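# NOTE(review): stray web-page labels from a scrape ("Comment list",
# "Table of contents") -- not part of the code.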