def apply(self, is_train, inputs, mask=None):
    # build the forward cell; reuse the forward spec for the backward cell if none is given
    fw = self.fw(is_train)
    bw_spec = self.fw if self.bw is None else self.bw
    bw = bw_spec(is_train)
    if self.merge is None:
        # no merge layer: concatenate forward/backward outputs along the feature axis
        return tf.concat(bidirectional_dynamic_rnn(fw, bw, inputs, mask, swap_memory=self.swap_memory,
                                                   dtype=tf.float32)[0], axis=2)
    else:
        fw, bw = bidirectional_dynamic_rnn(fw, bw, inputs, mask,
                                           swap_memory=self.swap_memory, dtype=tf.float32)[0]
        return self.merge.apply(is_train, fw, bw)  # TODO this should be in a different scope
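
# The snippet above calls a project-local bidirectional_dynamic_rnn wrapper that accepts a
# mask. A minimal sketch of the equivalent call with the stock TF 1.x API, where padding is
# handled through an integer `sequence_length` instead of a mask; the input/length names and
# the GRU cell size are assumptions for illustration only.
import tensorflow as tf

def birnn_concat_sketch(inputs, lengths, num_units=128):
    # inputs: [batch, time, depth]; lengths: [batch] true sequence lengths
    fw_cell = tf.nn.rnn_cell.GRUCell(num_units)
    bw_cell = tf.nn.rnn_cell.GRUCell(num_units)
    (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
        fw_cell, bw_cell, inputs,
        sequence_length=lengths,        # zeroes outputs past each example's length
        swap_memory=True, dtype=tf.float32)
    # concatenate the two directions along the feature axis -> [batch, time, 2 * num_units]
    return tf.concat([out_fw, out_bw], axis=2)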
def _add_encoders(self):
    with tf.variable_scope('query_encoder'):
        query_encoder_cell = GRUCell(self.encoder_cell_state_size)
        if self.dropout_enabled and self.mode != 'decode':
            # apply output dropout only outside decode mode
            query_encoder_cell = DropoutWrapper(cell=query_encoder_cell, output_keep_prob=0.8)

        query_embeddings = tf.nn.embedding_lookup(self.embeddings, self.queries_placeholder)
        query_encoder_outputs, _ = rnn.dynamic_rnn(query_encoder_cell, query_embeddings,
                                                   sequence_length=self.query_lengths_placeholder,
                                                   swap_memory=True, dtype=tf.float32)
        self.query_last = query_encoder_outputs[:, -1, :]

    with tf.variable_scope('encoder'):
        fw_cell = GRUCell(self.encoder_cell_state_size)
        bw_cell = GRUCell(self.encoder_cell_state_size)
        if self.dropout_enabled and self.mode != 'decode':
            fw_cell = DropoutWrapper(cell=fw_cell, output_keep_prob=0.8)
            bw_cell = DropoutWrapper(cell=bw_cell, output_keep_prob=0.8)

        embeddings = tf.nn.embedding_lookup(self.embeddings, self.documents_placeholder)
        (encoder_outputs_fw, encoder_outputs_bw), _ = rnn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell,
            embeddings,
            sequence_length=self.document_lengths_placeholder,
            swap_memory=True,
            dtype=tf.float32)

        # concatenate forward/backward document encodings along the feature axis
        self.encoder_outputs = tf.concat([encoder_outputs_fw, encoder_outputs_bw], 2)
        self.final_encoder_state = self.encoder_outputs[:, -1, :]
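
# Note on the `[:, -1, :]` slices above: with `sequence_length` set, outputs past each
# example's true length are zero vectors, so the literal last time step may be padding for
# shorter queries/documents. A hedged sketch of an alternative (the helper name and shapes
# are assumptions, not part of the original class) that gathers the output at each example's
# final valid step:
import tensorflow as tf

def last_relevant_output(outputs, lengths):
    # outputs: [batch, time, hidden]; lengths: [batch] true sequence lengths
    lengths = tf.cast(lengths, tf.int32)
    batch_size = tf.shape(outputs)[0]
    indices = tf.stack([tf.range(batch_size), lengths - 1], axis=1)  # (batch, time) pairs
    return tf.gather_nd(outputs, indices)                            # [batch, hidden]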
def build_multi_dynamic_brnn(args,
                             maxTimeSteps,
                             inputX,
                             cell_fn,
                             seqLengths,
                             time_major=True):
    hid_input = inputX
    for i in range(args.num_layer):
        scope = 'DBRNN_' + str(i + 1)
        forward_cell = cell_fn(args.num_hidden, activation=args.activation)
        backward_cell = cell_fn(args.num_hidden, activation=args.activation)
        # inputs are time-major: [max_time, batch_size, input_size]
        outputs, output_states = bidirectional_dynamic_rnn(forward_cell, backward_cell,
                                                           inputs=hid_input,
                                                           dtype=tf.float32,
                                                           sequence_length=seqLengths,
                                                           time_major=True,
                                                           scope=scope)
        # forward output, backward output
        # each of shape: [max_time, batch_size, num_hidden]
        output_fw, output_bw = outputs
        # forward states, backward states
        output_state_fw, output_state_bw = output_states
        # sum the two directions instead of concatenating them: concat to
        # [max_time, batch_size, 2 * num_hidden], split the last axis into the
        # two directions, then reduce over the direction axis
        # (pre-TF-1.0 argument order: tf.concat(2, [output_fw, output_bw]))
        output_fb = tf.concat([output_fw, output_bw], 2)
        shape = output_fb.get_shape().as_list()
        output_fb = tf.reshape(output_fb, [shape[0], shape[1], 2, int(shape[2] / 2)])
        hidden = tf.reduce_sum(output_fb, 2)
        hidden = dropout(hidden, args.keep_prob, (args.mode == 'train'))
        if i != args.num_layer - 1:
            # feed this layer's output into the next BRNN layer
            hid_input = hidden
        else:
            # last layer: flatten to [max_time * batch_size, num_hidden] and split back
            # into a list of per-time-step tensors of shape [batch_size, num_hidden]
            outputXrs = tf.reshape(hidden, [-1, args.num_hidden])
            # (pre-TF-1.0 argument order: tf.split(0, maxTimeSteps, outputXrs))
            output_list = tf.split(outputXrs, maxTimeSteps, 0)
            fbHrs = [tf.reshape(t, [args.batch_size, args.num_hidden]) for t in output_list]
    return fbHrs
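
# A hedged usage sketch for build_multi_dynamic_brnn. The `args` fields mirror the ones the
# function reads (num_layer, num_hidden, activation, keep_prob, mode, batch_size); the
# placeholder shapes, feature dimension, and GRUCell choice are assumptions for illustration.
import argparse
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell

args = argparse.Namespace(num_layer=2, num_hidden=256, activation=tf.nn.relu,
                          keep_prob=0.9, mode='train', batch_size=32)
max_time, feat_dim = 200, 39                        # time-major input: [max_time, batch, depth]
inputX = tf.placeholder(tf.float32, [max_time, args.batch_size, feat_dim])
seqLengths = tf.placeholder(tf.int32, [args.batch_size])
fbHrs = build_multi_dynamic_brnn(args, max_time, inputX, GRUCell, seqLengths)
# fbHrs is a list of max_time tensors, each of shape [batch_size, num_hidden]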
def apply(self, is_train, x, mask=None):
    # keep only the final states (index [1]) of the forward and backward cells
    states = bidirectional_dynamic_rnn(self.cell_spec(is_train), self.cell_spec(is_train),
                                       x, mask, dtype=tf.float32)[1]
    output = []
    for state in states:
        # pick the named component of each state (e.g. 'c' or 'h' of an LSTMStateTuple)
        for i, name in enumerate(state._fields):
            if name == self.output:
                output.append(state[i])
    if self.merge is not None:
        return self.merge.apply(is_train, output[0], output[1])
    else:
        return tf.concat(output, axis=1)
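
# For context on the `_fields` loop above: when the cells are LSTMs, each direction's final
# state is an LSTMStateTuple, a namedtuple with fields ('c', 'h'), so setting `self.output`
# to 'h' (or 'c') selects that component of both the forward and backward states. A minimal
# sketch with stock TF 1.x types (the tensors and the `wanted` name are illustrative):
import tensorflow as tf

state = tf.nn.rnn_cell.LSTMStateTuple(c=tf.zeros([4, 8]), h=tf.ones([4, 8]))
wanted = 'h'
selected = [state[i] for i, name in enumerate(state._fields) if name == wanted]
# selected == [state.h]; note that GRU cells return a plain tensor with no `_fields`,
# so this selection scheme assumes tuple-style states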
def _apply(self, X, state=None, memory=None):
    # time_major: the shape format of the `inputs` and `outputs` Tensors.
    #   If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
    #   If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
    # ====== create attention if necessary ====== #
    cell = self.cell
    if self.bidirectional:
        cell_bw = self.cell_bw
    # create attention cell
    if self.attention:
        if not hasattr(self, "_cell_with_attention"):
            self._cell_with_attention = self.__attention_creator(
                cell, X=X, memory=memory)
        cell = self._cell_with_attention
        # bidirectional attention
        if self.bidirectional:
            if not hasattr(self, "_cell_with_attention_bw"):
                self._cell_with_attention_bw = self.__attention_creator(
                    cell_bw, X=X, memory=memory)
            cell_bw = self._cell_with_attention_bw
    # ====== call the rnn wrapper ====== #
    ## Bidirectional
    if self.bidirectional:
        rnn_func = rnn.bidirectional_dynamic_rnn if self.dynamic \
            else rnn.static_bidirectional_rnn
        # a single state seeds the forward direction; a (fw, bw) pair seeds both
        state_fw, state_bw = None, None
        if isinstance(state, (tuple, list)):
            state_fw = state[0]
            if len(state) > 1:
                state_bw = state[1]
        else:
            state_fw = state
        outputs = rnn_func(cell_fw=cell, cell_bw=cell_bw, inputs=X,
                           initial_state_fw=state_fw,
                           initial_state_bw=state_bw,
                           dtype=X.dtype.base_dtype)
    ## Unidirectional
    else:
        rnn_func = rnn.dynamic_rnn if self.dynamic else rnn.static_rnn
        outputs = rnn_func(cell, inputs=X, initial_state=state,
                           dtype=X.dtype.base_dtype)
    # ====== initialize cell ====== #
    if not self._is_initialized_variables:
        # initialize only once; re-running the initializer would reset the
        # variable values on every call
        K.eval(tf.variables_initializer(self.variables))
        self._is_initialized_variables = True
        _infer_variable_role(self.variables)
    # ====== return ====== #
    if self.bidirectional:  # concat the forward/backward outputs on the last axis
        outputs = (tf.concat(outputs[0], axis=-1), outputs[1])
    if not self.return_states:
        return outputs[0]
    return outputs
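
# A minimal sketch (assumed shapes and LSTM cells, not the class above) of the bidirectional
# branch's state handling: a single initial state seeds the forward direction, a
# (state_fw, state_bw) pair seeds both, and the two output tensors are then concatenated on
# the last axis just as `tf.concat(outputs[0], axis=-1)` does above.
import tensorflow as tf

batch, time, depth, units = 4, 10, 16, 32
X = tf.random_normal([batch, time, depth])
cell_fw = tf.nn.rnn_cell.LSTMCell(units)
cell_bw = tf.nn.rnn_cell.LSTMCell(units)
state_fw = cell_fw.zero_state(batch, tf.float32)
state_bw = cell_bw.zero_state(batch, tf.float32)
(out_fw, out_bw), (final_fw, final_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, X,
    initial_state_fw=state_fw, initial_state_bw=state_bw,
    dtype=X.dtype.base_dtype)
merged = tf.concat([out_fw, out_bw], axis=-1)   # [batch, time, 2 * units]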