def initial_states_tuple(self):
    """
    Create the initial state tensors for the individual RNN cells.

    If no initial state vector was passed to this RNN, all initial states are set to be zero. Otherwise, the
    initial state vector is split into a possibly nested tuple of tensors according to the RNN architecture.
    The return value of this function is structured in such a way that it can be passed to the
    `initial_state` parameter of the RNN functions in `tf.contrib.rnn`.

    The flat state vector is split along axis 1 in this order:

    1. into `self.num_layers` equal parts, one per layer;
    2. if `self.bidirectional`, each layer part into two equal halves (first half -> forward,
       second half -> backward);
    3. if `self.cell_type == CellType.LSTM`, each remaining part into two equal halves wrapped in an
       `LSTMStateTuple` (first half -> `c`, second half -> `h`).

    Every split uses `tf.split` with an integer count, so the size of axis 1 must divide evenly at each step.

    Returns
    -------
    tuple of tf.Tensor
        A possibly nested tuple of initial state tensors for the RNN cells:

        - unidirectional: one entry per layer;
        - bidirectional: a pair `(forward_states, backward_states)`, each with one entry per layer.

        For LSTM cells, each per-layer entry is an `LSTMStateTuple` instead of a plain tensor.
    """

    def _split_lstm(states):
        # Each per-layer LSTM slice is treated as the concatenation [c, h] and split into an LSTMStateTuple.
        return tuple(LSTMStateTuple(*tf.split(state, 2, axis=1)) for state in states)

    if self.initial_state is None:
        flat_state = tf.zeros(shape=[self.batch_size, self.state_size], dtype=tf.float32)
    else:
        # NOTE(review): presumably laid out like the zero default, i.e. [batch_size, state_size] — confirm.
        flat_state = self.initial_state

    is_lstm = self.cell_type == CellType.LSTM
    layer_states = tuple(tf.split(flat_state, self.num_layers, axis=1))

    if not self.bidirectional:
        return _split_lstm(layer_states) if is_lstm else layer_states

    # Split every layer slice into forward/backward halves, then regroup the halves by direction.
    states_fw, states_bw = zip(*(tf.split(layer, 2, axis=1) for layer in layer_states))
    if is_lstm:
        states_fw = _split_lstm(states_fw)
        states_bw = _split_lstm(states_bw)
    return states_fw, states_bw
# Web-page residue from the scraped source: "评论列表" (comment list), "文章目录" (table of contents).