def apply(self, is_train, x, mask=None):
    """Run the recurrent cell over ``x`` and return (part of) its final state.

    The cell is built by ``self.cell_spec(is_train)`` and unrolled with
    ``dynamic_rnn``; only element ``[1]`` of its result is used (presumably
    the final state, as in ``tf.nn.dynamic_rnn``'s ``(outputs, state)``
    return -- TODO confirm which ``dynamic_rnn`` this module imports).

    ``self.output`` selects what is returned:

    * an ``int`` -- ``state[self.output]``;
    * ``None`` -- the whole state, which must then be a single ``tf.Tensor``;
    * anything else -- the state component whose name in ``state._fields``
      equals ``self.output`` (i.e. the state is expected to be a
      namedtuple-like object).

    Args:
        is_train: Forwarded to ``self.cell_spec`` when building the cell.
        x: Input sequence passed to ``dynamic_rnn``.
        mask: Passed as the third positional argument of ``dynamic_rnn``
            (``sequence_length`` for ``tf.nn.dynamic_rnn``; presumably
            per-example lengths -- verify against callers).

    Returns:
        The selected part of the final RNN state.

    Raises:
        ValueError: if ``self.output`` is ``None`` but the state is not a
            single ``tf.Tensor``, or if ``self.output`` does not name a
            field of the state.
    """
    state = dynamic_rnn(self.cell_spec(is_train), x, mask, dtype=tf.float32)[1]

    # Positional selection works for both tuple states and indexable tensors.
    if isinstance(self.output, int):
        return state[self.output]

    if self.output is None:
        if not isinstance(state, tf.Tensor):
            raise ValueError(
                "output=None requires the RNN state to be a single tf.Tensor, "
                "got %s" % type(state).__name__)
        return state

    # Named selection: the state must expose ``_fields`` (namedtuple-like).
    # A state without ``_fields`` is reported the same way as a missing name
    # instead of surfacing an incidental AttributeError.
    field_names = getattr(state, "_fields", None)
    if field_names is None or self.output not in field_names:
        raise ValueError(
            "output %r is not a field of the RNN state (available fields: %s)"
            % (self.output, field_names))
    # ``index`` returns the first match, same as scanning the fields in order.
    return state[field_names.index(self.output)]
评论列表
文章目录