def initial_loop_state(self) -> MultiHeadLoopStateTA:
    """Build the empty loop state for multi-head attention.

    Every array is a ``tf.float32`` ``tf.TensorArray`` that starts empty
    (``size=0``) and grows as elements are written (``dynamic_size=True``).

    Returns:
        A ``MultiHeadLoopStateTA`` whose ``contexts`` field is a single
        array named ``"contexts"``. Its ``head_weights`` field is a list
        with one array per attention head (``self.n_heads`` entries), named
        ``"distributions_head0"``, ``"distributions_head1"``, and so on.
    """
    # Create the contexts array first, matching the original creation order.
    contexts_array = tf.TensorArray(
        dtype=tf.float32,
        size=0,
        dynamic_size=True,
        name="contexts",
    )

    head_weight_arrays = [
        tf.TensorArray(
            dtype=tf.float32,
            size=0,
            dynamic_size=True,
            name="distributions_head{}".format(head_idx),
            # TensorArray clears values after reading them by default.
            # Disabling that keeps each head's weights readable more than
            # once; presumably a caller reads them repeatedly (TODO confirm).
            clear_after_read=False,
        )
        for head_idx in range(self.n_heads)
    ]

    return MultiHeadLoopStateTA(
        contexts=contexts_array,
        head_weights=head_weight_arrays,
    )
评论列表
文章目录