def build_rnn_layer(self, inputs_, train_phase):
    """
    Build the computation graph from inputs to outputs of the RNN layer.
    :param inputs_: [batch, t, emb], float
    :param train_phase: bool
    :return: rnn_outputs_: [batch, t, hid_dim], float
    """
    config = self.config
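    # `config` fields referenced in this method: rnn ('rnn' | 'lstm' | 'gru'), hidden_dim,
    # num_layers, keep_prob, batch_size, max_seq_len, fix_seq_len, use_seq_len_in_rnn,
    # and float_type (e.g. tf.float32).
    # Note: the calls below use the pre-1.0 TensorFlow API; in TF >= 1.0, tf.pack/tf.unpack
    # were renamed tf.stack/tf.unstack, tf.concat's argument order became
    # tf.concat(values, axis), and tf.nn.rnn became static_rnn.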
    def unrolled_rnn(cell, emb_inputs_, initial_state_, seq_len_):
        if not config.fix_seq_len:
            raise Exception("`config.fix_seq_len` should be set to `True` if using unrolled_rnn()")
        outputs = []
        state = initial_state_
        with tf.variable_scope("unrolled_rnn"):
            for t in range(config.max_seq_len):
                if t > 0:
                    # share the cell's weights across time steps
                    tf.get_variable_scope().reuse_variables()
                output, state = cell(emb_inputs_[:, t], state)  # [batch, hid_dim]
                outputs.append(output)
        rnn_outputs_ = tf.pack(outputs, axis=1)  # [batch, t, hid_dim]
        return rnn_outputs_
    def dynamic_rnn(cell, emb_inputs_, initial_state_, seq_len_):
        # dtype must be given if initial_state_ is not provided
        rnn_outputs_, last_states_ = tf.nn.dynamic_rnn(cell, emb_inputs_, initial_state=initial_state_,
                                                       sequence_length=seq_len_,
                                                       dtype=config.float_type)
        return rnn_outputs_
    def bidirectional_rnn(cell, emb_inputs_, initial_state_, seq_len_):
        # the same cell is used for the forward and backward directions
        rnn_outputs_, output_states = tf.nn.bidirectional_dynamic_rnn(cell, cell, emb_inputs_, seq_len_,
                                                                      initial_state_, initial_state_,
                                                                      config.float_type)
        # concatenate forward and backward outputs -> [batch, t, 2 * hid_dim]
        return tf.concat(2, rnn_outputs_)
    def rnn(cell, emb_inputs_, initial_state_, seq_len_):
        if not config.fix_seq_len:
            raise Exception("`config.fix_seq_len` should be set to `True` if using rnn()")
        # tf.nn.rnn expects a length-t list of [batch, emb] tensors
        inputs_ = tf.unpack(emb_inputs_, axis=1)
        outputs_, states_ = tf.nn.rnn(cell, inputs_, initial_state_, dtype=config.float_type,
                                      sequence_length=seq_len_)
        # pack the per-step outputs into [batch, t, hid_dim] so the return shape matches the other helpers
        return tf.pack(outputs_, axis=1)
    if config.rnn == 'rnn':
        cell = tf.nn.rnn_cell.BasicRNNCell(config.hidden_dim)
    elif config.rnn == 'lstm':
        cell = tf.nn.rnn_cell.BasicLSTMCell(config.hidden_dim)
    elif config.rnn == 'gru':
        cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
    else:
        raise Exception("`config.rnn` should be one of 'rnn', 'lstm' or 'gru'.")
    if train_phase and config.keep_prob < 1:
        # apply dropout to the cell outputs only at training time
        cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=config.keep_prob)
    if config.num_layers is not None and config.num_layers > 1:
        # pre-TF-1.1 idiom: reusing one cell object still creates separate variables per layer;
        # newer TF versions require constructing a fresh cell for each layer
        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * config.num_layers)
    initial_state_ = cell.zero_state(config.batch_size, dtype=config.float_type)
    if config.use_seq_len_in_rnn:
        seq_len_ = self.seq_len_
    else:
        seq_len_ = None
    # dynamic_rnn is used here; the helpers above share the same signature and can be
    # swapped in (note that bidirectional_rnn returns [batch, t, 2 * hid_dim])
    rnn_outputs_ = dynamic_rnn(cell, inputs_, initial_state_, seq_len_)  # [batch, t, hid_dim]
    return rnn_outputs_
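Below is a minimal usage sketch, not part of the original code: `DemoConfig` and `DemoModel` are hypothetical stand-ins that only show what `build_rnn_layer()` expects from `self`, namely a `config` attribute with the fields listed above and a `seq_len_` placeholder. It is written against the same pre-1.0 TensorFlow API as the method itself, and the embedding size of 300 is arbitrary.

import tensorflow as tf

class DemoConfig(object):
    rnn = 'gru'                  # one of 'rnn' | 'lstm' | 'gru'
    hidden_dim = 128
    num_layers = 2
    keep_prob = 0.8
    batch_size = 32
    max_seq_len = 50
    fix_seq_len = True
    use_seq_len_in_rnn = True
    float_type = tf.float32

class DemoModel(object):
    # the method above only needs `self.config` and `self.seq_len_`
    def __init__(self, config, seq_len_):
        self.config = config
        self.seq_len_ = seq_len_

config = DemoConfig()
emb_inputs_ = tf.placeholder(config.float_type, [config.batch_size, config.max_seq_len, 300])
seq_len_ = tf.placeholder(tf.int32, [config.batch_size])
model = DemoModel(config, seq_len_)

# build_rnn_layer() is treated as a plain function here, so `self` is passed explicitly
rnn_outputs_ = build_rnn_layer(model, emb_inputs_, train_phase=True)  # [batch, t, hid_dim]

If `build_rnn_layer` is defined inside a model class in the original project, the equivalent call is `model.build_rnn_layer(emb_inputs_, train_phase=True)` on an instance of that class.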