def _create_cell(self, seq, no_stacked_cells):
    """
    Create a stack of GRU layers and a feedable placeholder for its initial state.

    :param seq: placeholder of the input batch; only its leading (batch) dimension
        is read here, to size the zero state
    :param no_stacked_cells: number of GRU layers to stack in the MultiRNNCell
    :return: (cell, in_state) where ``cell`` is the MultiRNNCell and ``in_state``
        is a tuple with one ``[None, self.hidden_size]`` placeholder per layer,
        each defaulting to that layer's zero state when not fed
    """
    batch_size = tf.shape(seq)[0]
    # Build a distinct GRUCell object for every layer. Reusing a single cell
    # instance across layers (the old ``[cell] * n`` idiom) is rejected by
    # newer TF 1.x releases.
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.GRUCell(self.hidden_size) for _ in range(no_stacked_cells)]
    )
    zero_state = cell.zero_state(batch_size, tf.float32)
    # placeholder_with_default yields the zero state unless a value is fed.
    # Presumably this lets callers carry the hidden state between session runs.
    # NOTE(review): confirm against the callers.
    in_state = tuple(
        tf.placeholder_with_default(layer_zero, [None, self.hidden_size], name='in_state')
        for layer_zero in zero_state
    )
    return cell, in_state
评论列表
文章目录