def build(self):
    self._define_input()

    output = self.input_seq
    output = embedding(output, self.vocab.size, self.embedding_dim, name='layer_embedding')
    input_dim = self.embedding_dim

    # Prepare the data shape to match the static rnn API:
    # current input shape:  [batch_size, num_steps, input_dim]
    # required shape: a list of 'num_steps' tensors of shape [batch_size, input_dim]
    output = tf.transpose(output, [1, 0, 2])
    output = tf.reshape(output, [-1, input_dim])
    output = tf.split(output, self.num_steps, 0)

    if self.bidirectional:
        # Outputs: a list of 'num_steps' tensors of shape [batch_size, rnn_units * 2]
        fw_cell = build_cell(self.rnn_units, self.cell_type, self.rnn_layers)
        bw_cell = build_cell(self.rnn_units, self.cell_type, self.rnn_layers)
        output, state_fw, state_bw = rnn.static_bidirectional_rnn(
            fw_cell, bw_cell, output,
            dtype=tf.float32, sequence_length=self.seq_len, scope='encoder')

        # Merge the forward and backward final states so downstream code
        # (e.g. a decoder) sees a single state of width rnn_units * 2.
        if isinstance(state_fw, tf.contrib.rnn.LSTMStateTuple):  # LSTM cell
            encoder_state_c = tf.concat([state_fw.c, state_bw.c], axis=1, name='bidirectional_concat_c')
            encoder_state_h = tf.concat([state_fw.h, state_bw.h], axis=1, name='bidirectional_concat_h')
            state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(state_fw, tf.Tensor):  # GRU or plain RNN cell
            state = tf.concat([state_fw, state_bw], axis=1, name='bidirectional_concat')
        else:
            raise ValueError('unsupported encoder state type: %s' % type(state_fw))
    else:
        # Outputs: a list of 'num_steps' tensors of shape [batch_size, rnn_units]
        cell = build_cell(self.rnn_units, self.cell_type, self.rnn_layers)
        output, state = rnn.static_rnn(
            cell, output, dtype=tf.float32, sequence_length=self.seq_len, scope='encoder')

    # Restore the batch-major layout (rnn_units * 2 in the bidirectional case).
    output = tf.stack(output, axis=0)         # [num_steps, batch_size, rnn_units]
    output = tf.transpose(output, [1, 0, 2])  # [batch_size, num_steps, rnn_units]

    self.encoder_output = output
    self.encoder_state = state
    return output, state
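The method relies on two helpers that are not shown above: `embedding` and `build_cell`, plus the module-level imports `tf` and `rnn`. Below is a minimal sketch of what they might look like under TF 1.x; the supported cell-type strings ('lstm' / 'gru'), the initializer, and the MultiRNNCell stacking are assumptions for illustration, not the post's actual implementation.

```python
import tensorflow as tf
from tensorflow.contrib import rnn

def embedding(input_seq, vocab_size, embedding_dim, name):
    # Look up trainable embeddings for integer token ids:
    # [batch_size, num_steps] -> [batch_size, num_steps, embedding_dim]
    with tf.variable_scope(name):
        embed_matrix = tf.get_variable(
            'embed_matrix', [vocab_size, embedding_dim],
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
        return tf.nn.embedding_lookup(embed_matrix, input_seq)

def build_cell(rnn_units, cell_type, rnn_layers):
    # One cell per layer; stack them when more than one layer is requested.
    def single_cell():
        if cell_type == 'lstm':
            return rnn.BasicLSTMCell(rnn_units)
        elif cell_type == 'gru':
            return rnn.GRUCell(rnn_units)
        raise ValueError('unknown cell type: %s' % cell_type)
    if rnn_layers == 1:
        return single_cell()
    return rnn.MultiRNNCell([single_cell() for _ in range(rnn_layers)])
```

Two things worth noting about the design. The transpose/reshape/split dance exists only because `static_rnn` and `static_bidirectional_rnn` unroll the graph and take a Python list of per-step tensors; `tf.nn.dynamic_rnn` accepts the batch-major tensor directly and would make that preprocessing unnecessary. Also, with `rnn_layers > 1` the final state is a tuple of per-layer states rather than a single `LSTMStateTuple` or `Tensor`, so the bidirectional branch falls through to the `ValueError`; a multi-layer bidirectional encoder would need per-layer state concatenation.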