def RNN(self, inputs):
    """Run the model's recurrent cell over `inputs`, optionally bidirectionally.

    A cell is built with `self.recur_cell`, sized from the last static
    dimension of `inputs`. It is unrolled with `rnn.dynamic_rnn`, or with
    `rnn.dynamic_bidirectional_rnn` when `self.recur_bidir` is true.

    Dropout keep probabilities come from `self.ff_keep_prob` and
    `self.recur_keep_prob`. When `self.moving_params` is set, both are
    fixed at 1 instead, i.e. no dropout.

    Args:
      inputs: Tensor whose last static dimension is used as the cell's input
        size. Presumably shaped (batch, time, features) -- TODO confirm.

    Returns:
      A pair (top_recur, end_recur):
        top_recur: the first value returned by the rnn helper (presumably
          the per-step outputs). Its static shape is set to rank 3 with last
          dimension `self.recur_size`, or `2 * self.recur_size` when
          bidirectional.
        end_recur: if bidirectional, the second halves (split along axis 1)
          of the forward and backward final states, concatenated along
          axis 1. If unidirectional, the final state exactly as returned
          by `rnn.dynamic_rnn`.
    """
    input_size = inputs.get_shape().as_list()[-1]
    cell = self.recur_cell(self._config, input_size=input_size,
                           moving_params=self.moving_params)
    # One length per sequence, flattened to a 1-D int64 tensor.
    lengths = tf.reshape(tf.to_int64(self.sequence_lengths), [-1])

    if self.moving_params is None:
        ff_keep_prob = self.ff_keep_prob
        recur_keep_prob = self.recur_keep_prob
    else:
        # With moving_params supplied, dropout is disabled (keep prob 1);
        # presumably this is the evaluation copy of the model -- confirm.
        ff_keep_prob = 1
        recur_keep_prob = 1

    if self.recur_bidir:
        # The same cell object is passed for both directions.
        top_recur, fw_recur, bw_recur = rnn.dynamic_bidirectional_rnn(
            cell, cell, inputs, lengths,
            ff_keep_prob=ff_keep_prob,
            recur_keep_prob=recur_keep_prob,
            dtype=tf.float32)
        # Pre-1.0 TensorFlow argument order: tf.split(split_dim, num_split, value).
        # Only the second half of each final state is kept.
        # NOTE(review): this assumes the state is laid out as [cell, output]
        # along axis 1 (LSTM-style) -- confirm for every configured recur_cell.
        _, fw_out = tf.split(1, 2, fw_recur)
        _, bw_out = tf.split(1, 2, bw_recur)
        # Pre-1.0 TensorFlow argument order: tf.concat(concat_dim, values).
        end_recur = tf.concat(1, [fw_out, bw_out])
        output_size = 2 * self.recur_size
    else:
        top_recur, end_recur = rnn.dynamic_rnn(
            cell, inputs, lengths,
            ff_keep_prob=ff_keep_prob,
            recur_keep_prob=recur_keep_prob,
            dtype=tf.float32)
        # NOTE(review): unlike the bidirectional branch, end_recur here is the
        # whole final state, not only its second half -- confirm callers
        # expect this asymmetry before changing it.
        output_size = self.recur_size

    # Pin the static feature dimension so downstream layers can infer sizes.
    top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(output_size)])
    return top_recur, end_recur
#=============================================================
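# "Comment list" -- leftover label from the web page this code was copied from.
# "Table of contents" -- leftover label from the web page this code was copied from.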