def _apply(self, X, state=None, memory=None):
  # time_major: The shape format of the `inputs` and `outputs` Tensors.
  # If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
  # If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
  # ====== create attention if necessary ====== #
  cell = self.cell
  if self.bidirectional:
    cell_bw = self.cell_bw
  # create attention cell
  if self.attention:
    if not hasattr(self, "_cell_with_attention"):
      self._cell_with_attention = self.__attention_creator(
          cell, X=X, memory=memory)
    cell = self._cell_with_attention
    # bidirectional attention
    if self.bidirectional:
      if not hasattr(self, "_cell_with_attention_bw"):
        self._cell_with_attention_bw = self.__attention_creator(
            cell_bw, X=X, memory=memory)
      cell_bw = self._cell_with_attention_bw
  # ====== calling rnn_wrapper ====== #
  ## Bidirectional
  if self.bidirectional:
    rnn_func = rnn.bidirectional_dynamic_rnn if self.dynamic \
        else rnn.static_bidirectional_rnn
    # split the given initial state into forward and backward parts
    state_fw, state_bw = None, None
    if isinstance(state, (tuple, list)):
      state_fw = state[0]
      if len(state) > 1:
        state_bw = state[1]
    else:
      state_fw = state
    # for the dynamic case: outputs = ((output_fw, output_bw), (state_fw, state_bw))
    outputs = rnn_func(cell_fw=cell, cell_bw=cell_bw, inputs=X,
                       initial_state_fw=state_fw,
                       initial_state_bw=state_bw,
                       dtype=X.dtype.base_dtype)
  ## Unidirectional
  else:
    rnn_func = rnn.dynamic_rnn if self.dynamic else rnn.static_rnn
    # for the dynamic case: outputs = (output, final_state)
    outputs = rnn_func(cell, inputs=X, initial_state=state,
                       dtype=X.dtype.base_dtype)
  # ====== initialize cell ====== #
  if not self._is_initialized_variables:
    # initialize only once; running the initializer again would
    # overwrite the variables' current values
    K.eval(tf.variables_initializer(self.variables))
    self._is_initialized_variables = True
    _infer_variable_role(self.variables)
  # ====== return ====== #
  if self.bidirectional:  # concat outputs
    outputs = (tf.concat(outputs[0], axis=-1), outputs[1])
  if not self.return_states:
    return outputs[0]
  return outputs
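The `__attention_creator` helper that wraps `cell` above is not shown in this snippet. As a rough illustration only (an assumption about what such a helper typically produces, not this library's actual implementation), a TF 1.x attention cell is commonly built by wrapping the base cell with `tf.contrib.seq2seq`:

import tensorflow as tf

# hypothetical names; `memory` stands in for the encoder outputs passed to _apply
memory = tf.placeholder(tf.float32, shape=[None, 20, 64])
base_cell = tf.nn.rnn_cell.GRUCell(num_units=64)
attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=64, memory=memory)
attention_cell = tf.contrib.seq2seq.AttentionWrapper(base_cell, attention_mechanism)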
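For reference, the bidirectional branch is a thin wrapper around TensorFlow's own RNN functions. Below is a minimal standalone sketch of the same flow in raw TF 1.x, assuming batch-major inputs (`time_major=False`); the cell type, sizes, and names are illustrative, not the library's defaults:

import tensorflow as tf

max_time, depth = 20, 128
X = tf.placeholder(tf.float32, shape=[None, max_time, depth])  # [batch, time, depth]

cell_fw = tf.nn.rnn_cell.LSTMCell(num_units=64)
cell_bw = tf.nn.rnn_cell.LSTMCell(num_units=64)

# outputs is a pair (output_fw, output_bw); states holds the two final states
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell_fw, cell_bw=cell_bw, inputs=X, dtype=tf.float32)

# merge forward/backward features, mirroring `tf.concat(outputs[0], axis=-1)` above
merged = tf.concat(outputs, axis=-1)  # shape: [batch, max_time, 128]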