def apply(self, is_train, x, mask=None):
    """Encode ``x`` with a bidirectional RNN and return selected final states.

    Builds two cells via ``self.cell_spec(is_train)`` (one per direction),
    runs ``bidirectional_dynamic_rnn`` over ``x`` and keeps element ``[1]``
    of its result (the per-direction final states). From each direction's
    final state tuple, the field named ``self.output`` is selected, and the
    selections are either merged with ``self.merge`` or concatenated.

    Args:
        is_train: Forwarded to ``self.cell_spec`` and, when a merge layer is
            configured, to ``self.merge.apply``.
        x: Input sequence tensor handed to the RNN.
        mask: Passed positionally right after ``x``. In
            ``tf.nn.bidirectional_dynamic_rnn`` that slot is
            ``sequence_length`` -- presumably per-example lengths; confirm
            against callers.

    Returns:
        ``self.merge.apply(is_train, fw, bw)`` when ``self.merge`` is not
        None, otherwise ``tf.concat`` of the selected states along axis 1.

    Raises:
        ValueError: If a final state is not a named tuple exposing a field
            called ``self.output`` (e.g. a typo such as "H" instead of "h").
    """
    # Element [1] of the RNN result holds the final states, one per direction.
    states = bidirectional_dynamic_rnn(
        self.cell_spec(is_train), self.cell_spec(is_train), x, mask,
        dtype=tf.float32)[1]

    selected = []
    for state in states:
        # Final states are expected to be namedtuples (e.g. an LSTM (c, h)
        # tuple); validate up front instead of failing later with an opaque
        # IndexError on a short/empty list.
        fields = getattr(state, "_fields", ())
        if self.output not in fields:
            raise ValueError(
                "RNN final state of type {!r} has no field {!r} "
                "(available fields: {!r})".format(
                    type(state).__name__, self.output, tuple(fields)))
        selected.append(getattr(state, self.output))

    if self.merge is not None:
        return self.merge.apply(is_train, selected[0], selected[1])
    # NOTE(review): axis=1 assumes each selected state is (batch, units) --
    # confirm against the cell types used with this encoder.
    return tf.concat(selected, axis=1)
评论列表
文章目录