def _initial_lstm_state(state_size, batch_size, initializer_fn, initializer_params):
    """Build an initial-state structure that mirrors ``state_size``.

    ``state_size`` is either an int (a flat state) or a possibly nested tuple of
    ints. A stacked cell typically reports one entry per layer, each of which
    may itself be a tuple. Every int leaf ``s`` becomes
    ``initializer_fn(shape=[batch_size, s], **initializer_params)``.
    Namedtuple containers are rebuilt with their own type, so the result keeps
    the same nesting and container types as ``state_size``.
    """
    if isinstance(state_size, tuple):
        parts = [
            _initial_lstm_state(s, batch_size, initializer_fn, initializer_params)
            for s in state_size
        ]
        # Namedtuples take their fields positionally; plain tuples take an iterable.
        if hasattr(state_size, '_fields'):
            return type(state_size)(*parts)
        return tuple(parts)
    return initializer_fn(shape=[batch_size, state_size], **initializer_params)


def lstm(
    inputs,
    num_units,
    num_layers=1,
    initializer_fn=tf.truncated_normal,
    initializer_params=_default_initializer_params,
    dtype=tf.float32,
    scope=None
):
    """Run a (possibly stacked) LSTMBlockCell over ``inputs`` with a static RNN.

    Args:
      inputs: Tensor that is split into per-step tensors along axis 1. Axis 0
        is treated as the batch dimension and must be statically known.
        Presumably shaped [batch, time, features], since ``tf.nn.rnn`` expects
        rank-2 per-step inputs. TODO confirm against callers.
      num_units: Number of units in each LSTMBlockCell.
      num_layers: Number of stacked layers. Values > 1 wrap one independent
        LSTMBlockCell per layer in a MultiRNNCell.
      initializer_fn: Callable invoked as ``f(shape=[batch, size], **params)``
        to create every initial-state tensor. Defaults to
        ``tf.truncated_normal``.
      initializer_params: Extra keyword arguments for ``initializer_fn``.
        ``None`` is treated as empty. The dict is copied, never mutated, and
        the copy's ``dtype`` entry is overwritten with ``dtype``.
      dtype: dtype of the initial state; also passed to ``tf.nn.rnn``.
      scope: Optional variable scope forwarded to ``tf.nn.rnn``.

    Returns:
      The per-step RNN outputs re-stacked along axis 1.

    Raises:
      ValueError: if the batch dimension (axis 0) of ``inputs`` is not
        statically known.
    """
    print('input shape', inputs.get_shape())
    shape = inputs.get_shape().as_list()
    batch_size = shape[0]
    if batch_size is None:
        # The initial state is built with a static [batch_size, size] shape,
        # so an unknown batch dimension cannot be supported here.
        raise ValueError(
            'lstm requires a statically known batch size; got input shape %s'
            % (shape,))
    inputs_unpacked = tf.unpack(inputs, axis=1)

    block_cell = tf.contrib.rnn.python.ops.lstm_ops.LSTMBlockCell
    cell = block_cell(num_units=num_units)
    print('cell state size', cell.state_size)
    if num_layers > 1:
        # Use an independent cell object per layer instead of [cell] * num_layers.
        # Sharing one object across layers is fragile and later TF releases
        # reject it.
        cell = tf.nn.rnn_cell.MultiRNNCell(
            [cell] + [block_cell(num_units=num_units) for _ in range(num_layers - 1)])

    # Copy so neither the shared module-level default nor the caller's dict is
    # mutated (the original wrote 'dtype' into it in place).
    params = dict(initializer_params or {})
    params['dtype'] = dtype
    # Apply the same initializer and params to every state component,
    # including nested per-layer tuples.
    initial_state = _initial_lstm_state(
        cell.state_size, batch_size, initializer_fn, params)

    # tf.nn.rnn returns a 2-tuple: (list of per-step outputs, final state).
    outputs, _ = tf.nn.rnn(
        cell,
        inputs_unpacked,
        initial_state=initial_state,
        dtype=dtype,
        scope=scope)
    outputs = tf.pack(outputs, axis=1)
    print('output shape', outputs.get_shape())
    return outputs
评论列表
文章目录