# TensorFlow 1.x imports assumed by the snippets in this listing.
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib.rnn import LSTMStateTuple


def build_lstm(x, size, name, step_size):
    lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)

    # Zero-filled arrays used to reset the recurrent state between rollouts.
    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    state_init = [c_init, h_init]

    # Placeholders through which the previous state is fed back in.
    c_in = tf.placeholder(tf.float32,
                          shape=[1, lstm.state_size.c],
                          name='c_in')
    h_in = tf.placeholder(tf.float32,
                          shape=[1, lstm.state_size.h],
                          name='h_in')
    state_in = [c_in, h_in]

    lstm_state_in = rnn.LSTMStateTuple(c_in, h_in)
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm, x, initial_state=lstm_state_in, sequence_length=step_size,
        time_major=False)
    lstm_outputs = tf.reshape(lstm_outputs, [-1, size])

    lstm_c, lstm_h = lstm_state
    state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    return lstm_outputs, state_init, state_in, state_out
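A minimal usage sketch for the function above, assuming a hypothetical input placeholder `x` of shape `[1, None, feature_dim]`, a live `tf.Session` named `sess`, and an iterable `feature_batches`; the returned `state_init` arrays seed the first step, and `state_out` is fed back through the `state_in` placeholders on subsequent steps.

# Hypothetical rollout loop; x, sess, and feature_batches are assumptions.
outputs, state_init, state_in, state_out = build_lstm(
    x, size=256, name='lstm', step_size=tf.shape(x)[:1])

state = state_init  # reset at the start of an episode
for features in feature_batches:  # each element shaped [1, T, feature_dim]
    out, state = sess.run(
        [outputs, state_out],
        feed_dict={x: features,
                   state_in[0]: state[0],
                   state_in[1]: state[1]})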
Python LSTMStateTuple() usage examples (source code)
Source file: tree_encoder.py (project: almond-nnparser, author: Stanford-Mobisocial-IoT-Lab)
def __call__(self, left_state, right_state, extra_input=None):
    with tf.variable_scope('TreeLSTM'):
        c1, h1 = left_state
        c2, h2 = right_state
        if extra_input is not None:
            input_concat = tf.concat((extra_input, h1, h2), axis=1)
        else:
            input_concat = tf.concat((h1, h2), axis=1)
        concat = tf.layers.dense(input_concat, 5 * self._num_cells)
        i, f1, f2, o, g = tf.split(concat, 5, axis=1)
        i = tf.sigmoid(i)
        f1 = tf.sigmoid(f1)
        f2 = tf.sigmoid(f2)
        o = tf.sigmoid(o)
        g = tf.tanh(g)
        cnew = f1 * c1 + f2 * c2 + i * g
        hnew = o * cnew
        newstate = LSTMStateTuple(c=cnew, h=hnew)
        return hnew, newstate
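A hypothetical wiring sketch for the binary tree cell above, assuming the surrounding class is constructed as `TreeLSTMCell(num_cells)` and exposes the `zero_state` method shown later in this listing; `embeddings` is an assumed `[batch, dim]` tensor.

# Hypothetical usage; TreeLSTMCell and embeddings are assumed names.
cell = TreeLSTMCell(num_cells=128)
left = cell.zero_state(batch_size=32)    # leaf state for the left child
right = cell.zero_state(batch_size=32)   # leaf state for the right child
node_output, node_state = cell(left, right, extra_input=embeddings)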
def mask_finished(finished, now_, prev_):
    mask = tf.expand_dims(tf.to_float(finished), 1)
    if isinstance(prev_, tuple):
        # tuple states
        next_ = []
        for ns, s in zip(now_, prev_):
            # LSTMStateTuple components have to be masked separately
            if isinstance(ns, LSTMStateTuple):
                next_.append(
                    LSTMStateTuple(c=(1. - mask) * ns.c + mask * s.c,
                                   h=(1. - mask) * ns.h + mask * s.h))
            else:
                next_.append((1. - mask) * ns + mask * s)
        next_ = tuple(next_)
    else:
        next_ = (1. - mask) * now_ + mask * prev_
    return next_
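A hedged usage sketch inside a decoding loop: sequences that have already emitted the end-of-sequence token keep their previous state. `sampled_ids`, `eos_id`, `new_state`, and `old_state` are assumed names, not part of the listing.

# Hypothetical decode step; the state of finished sequences is frozen.
finished = tf.equal(sampled_ids, eos_id)        # [batch] bool
next_state = mask_finished(finished, new_state, old_state)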
def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(self.scope):
        c, h = state
        h = dropout(h, self.keep_recurrent_probs, self.is_train)
        mat = _compute_gates(inputs, h, self.num_units, self.forget_bias,
                             self.kernel_initializer, self.recurrent_initializer, True)
        i, j, f, o = tf.split(value=mat, num_or_size_splits=4, axis=1)
        new_c = (c * self.recurrent_activation(f) + self.recurrent_activation(i) *
                 self.activation(j))
        new_h = self.activation(new_c) * self.recurrent_activation(o)
        new_state = LSTMStateTuple(new_c, new_h)
        return new_h, new_state
Source file: md_lstm.py (project: tensorflow-multi-dimensional-lstm, author: philipperemy)
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM) with two predecessor cells.

    @param inputs: input tensor of shape (batch, n)
    @param state: the cell states and hidden units of the two predecessor cells
    """
    with tf.variable_scope(scope or type(self).__name__):
        c1, c2, h1, h2 = state
        # change bias argument to False since LN will add bias via shift
        concat = _linear([inputs, h1, h2], 5 * self._num_units, False)
        i, j, f1, f2, o = tf.split(value=concat, num_or_size_splits=5, axis=1)
        # add layer normalization to each gate
        i = ln(i, scope='i/')
        j = ln(j, scope='j/')
        f1 = ln(f1, scope='f1/')
        f2 = ln(f2, scope='f2/')
        o = ln(o, scope='o/')
        new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) +
                 c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) *
                 self._activation(j))
        # add layer normalization in the calculation of the new hidden state
        new_h = self._activation(ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
        new_state = LSTMStateTuple(new_c, new_h)
        return new_h, new_state
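This cell calls an `ln` helper that the listing does not include. A minimal layer-normalization sketch, assuming it normalizes each gate over the feature axis and learns a per-unit scale and shift; the variable names are assumptions.

def ln(tensor, scope=None, epsilon=1e-5):
    """Hypothetical layer-normalization helper matching the calls above."""
    with tf.variable_scope(scope or 'layer_norm'):
        num_units = tensor.get_shape().as_list()[1]
        scale = tf.get_variable('scale', [num_units], initializer=tf.ones_initializer())
        shift = tf.get_variable('shift', [num_units], initializer=tf.zeros_initializer())
        mean, variance = tf.nn.moments(tensor, axes=[1], keep_dims=True)
        normalized = (tensor - mean) / tf.sqrt(variance + epsilon)
        return normalized * scale + shift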
def create_architecture(self, **specs):
    self.vars.sequence_length = tf.placeholder(tf.int64, [1], name="sequence_length")
    fc_input = self.get_input_layers()
    fc1 = fully_connected(fc_input, num_outputs=self.fc_units_num,
                          scope=self._name_scope + "/fc1")
    fc1_reshaped = tf.reshape(fc1, [1, -1, self.fc_units_num])

    self.recurrent_cells = self.ru_class(self._recurrent_units_num)
    state_c = tf.placeholder(tf.float32, [1, self.recurrent_cells.state_size.c],
                             name="initial_lstm_state_c")
    state_h = tf.placeholder(tf.float32, [1, self.recurrent_cells.state_size.h],
                             name="initial_lstm_state_h")
    self.vars.initial_network_state = LSTMStateTuple(state_c, state_h)

    rnn_outputs, self.ops.network_state = tf.nn.dynamic_rnn(self.recurrent_cells,
                                                            fc1_reshaped,
                                                            initial_state=self.vars.initial_network_state,
                                                            sequence_length=self.vars.sequence_length,
                                                            time_major=False,
                                                            scope=self._name_scope)
    reshaped_rnn_outputs = tf.reshape(rnn_outputs, [-1, self._recurrent_units_num])

    self.reset_state()
    self.ops.pi, self.ops.v = self.policy_value_layer(reshaped_rnn_outputs)
def create_architecture(self):
    self.vars.sequence_length = tf.placeholder(tf.int64, [1], name="sequence_length")
    fc_input = self.get_input_layers()
    fc1 = fully_connected(fc_input,
                          num_outputs=self.fc_units_num,
                          scope=self._name_scope + "/fc1")
    fc1_reshaped = tf.reshape(fc1, [1, -1, self.fc_units_num])

    self.recurrent_cells = self.ru_class(self._recurrent_units_num)
    state_c = tf.placeholder(tf.float32, [1, self.recurrent_cells.state_size.c],
                             name="initial_lstm_state_c")
    state_h = tf.placeholder(tf.float32, [1, self.recurrent_cells.state_size.h],
                             name="initial_lstm_state_h")
    self.vars.initial_network_state = LSTMStateTuple(state_c, state_h)

    rnn_outputs, self.ops.network_state = tf.nn.dynamic_rnn(self.recurrent_cells,
                                                            fc1_reshaped,
                                                            initial_state=self.vars.initial_network_state,
                                                            sequence_length=self.vars.sequence_length,
                                                            time_major=False,
                                                            scope=self._name_scope)
    reshaped_rnn_outputs = tf.reshape(rnn_outputs, [-1, self._recurrent_units_num])

    self.reset_state()
    self.ops.pi, self.ops.frameskip_pi, self.ops.v = self.policy_value_frameskip_layer(reshaped_rnn_outputs)
def create_architecture(self):
    self.vars.sequence_length = tf.placeholder(tf.int64, [1], name="sequence_length")
    fc_input = self.get_input_layers()
    fc1 = layers.fully_connected(fc_input, self.fc_units_num, scope=self._name_scope + "/fc1")
    fc1_reshaped = tf.reshape(fc1, [1, -1, self.fc_units_num])

    self.recurrent_cells = self._get_ru_class()(self._recurrent_units_num)
    state_c = tf.placeholder(tf.float32, [1, self.recurrent_cells.state_size.c],
                             name="initial_lstm_state_c")
    state_h = tf.placeholder(tf.float32, [1, self.recurrent_cells.state_size.h],
                             name="initial_lstm_state_h")
    self.vars.initial_network_state = LSTMStateTuple(state_c, state_h)

    rnn_outputs, self.ops.network_state = tf.nn.dynamic_rnn(self.recurrent_cells,
                                                            fc1_reshaped,
                                                            initial_state=self.vars.initial_network_state,
                                                            sequence_length=self.vars.sequence_length,
                                                            scope=self._name_scope)
    reshaped_rnn_outputs = tf.reshape(rnn_outputs, [-1, self._recurrent_units_num])

    q = layers.linear(reshaped_rnn_outputs, num_outputs=self.actions_num,
                      scope=self._name_scope + "/q")
    self.reset_state()
    return q
def __call__(self, input, state, scope=None):  # TODO: test
    with tf.variable_scope(scope or type(self).__name__):
        # computation
        c_prev, h_prev = state
        with tf.variable_scope('mul'):
            concat = _linear([input, h_prev], 2 * self._num_units, True)
            proj_input, rec_input = tf.split(value=concat, num_or_size_splits=2, axis=1)
            mul_input = proj_input * rec_input  # equation (18)
        with tf.variable_scope('rec_input'):
            rec_mul_input = _linear(mul_input, 4 * self._num_units, True)
        b = tf.get_variable('b', [self._num_units * 4])
        lstm_mat = input + rec_mul_input + b
        i, j, f, o = tf.split(value=lstm_mat, num_or_size_splits=4, axis=1)
        # new_c, new_h
        new_c = (c_prev * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * tf.nn.tanh(j))
        new_h = tf.nn.tanh(new_c) * tf.nn.sigmoid(o)
        new_state = LSTMStateTuple(new_c, new_h)
        return new_h, new_state
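Several cells in this listing call a `_linear` helper that is not shown. A minimal sketch under the assumption that it mirrors TF 1.x's internal `_linear`: concatenate the inputs along the feature axis, apply a single weight matrix, and optionally add a bias; the variable names are assumptions.

def _linear(args, output_size, bias, bias_start=0.0, scope=None):
    """Hypothetical stand-in for TF 1.x's internal _linear helper."""
    if not isinstance(args, (list, tuple)):
        args = [args]
    total_input_size = sum(a.get_shape().as_list()[1] for a in args)
    with tf.variable_scope(scope or 'linear'):
        weights = tf.get_variable('weights', [total_input_size, output_size])
        output = tf.matmul(tf.concat(args, axis=1), weights)
        if bias:
            biases = tf.get_variable('biases', [output_size],
                                     initializer=tf.constant_initializer(bias_start))
            output = output + biases
        return output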
def __init__(self, x, size, step_size):
    lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)

    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    self.state_init = [c_init, h_init]

    c_in = tf.placeholder(tf.float32,
                          shape=[1, lstm.state_size.c],
                          name='c_in')
    h_in = tf.placeholder(tf.float32,
                          shape=[1, lstm.state_size.h],
                          name='h_in')
    self.state_in = [c_in, h_in]

    state_in = rnn.LSTMStateTuple(c_in, h_in)
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm, x, initial_state=state_in, sequence_length=step_size,
        time_major=False)
    lstm_outputs = tf.reshape(lstm_outputs, [-1, size])

    lstm_c, lstm_h = lstm_state
    self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    self.output = lstm_outputs
Source file: tree_encoder.py (project: almond-nnparser, author: Stanford-Mobisocial-IoT-Lab)
def zero_state(self, batch_size, dtype=tf.float32):
    zeros = tf.zeros((batch_size, self._num_cells), dtype=dtype)
    return LSTMStateTuple(zeros, zeros)
def __init__(self, ob_space, ac_space, lstm_size=256, use_categorical_max=False, **kwargs):
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
    rank = len(ob_space)
    if rank == 3:  # pixel input
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
    elif rank == 1:  # plain features
        # x = tf.nn.elu(linear(x, 256, "l1", normalized_columns_initializer(0.01)))
        pass
    else:
        raise TypeError("observation space must have rank 1 or 3, got %d" % rank)
    # introduce a "fake" batch dimension of 1 after flatten so that we can do LSTM over time dim
    x = tf.expand_dims(flatten(x), [0])

    size = lstm_size
    lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
    self.state_size = lstm.state_size
    step_size = tf.shape(self.x)[:1]

    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    self.state_init = [c_init, h_init]
    c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
    h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
    self.state_in = [c_in, h_in]

    state_in = rnn.LSTMStateTuple(c_in, h_in)
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm, x, initial_state=state_in, sequence_length=step_size,
        time_major=False)
    lstm_c, lstm_h = lstm_state
    x = tf.reshape(lstm_outputs, [-1, size])

    self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
    self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
    self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    self.sample = categorical_max(self.logits, ac_space)[0, :] \
        if use_categorical_max else categorical_sample(self.logits, ac_space)[0, :]
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
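A hedged sketch of how a recurrent policy like the one above is typically stepped during a rollout; `policy`, `sess`, `env`, and `last_obs` are assumed names, not part of the listing.

# Hypothetical single environment step with recurrent state threading.
state = policy.state_init
fetched = sess.run(
    [policy.sample, policy.vf, policy.state_out],
    feed_dict={policy.x: [last_obs],
               policy.state_in[0]: state[0],
               policy.state_in[1]: state[1]})
action, value, state = fetched[0], fetched[1], fetched[2]
last_obs, reward, done, _ = env.step(action.argmax())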
def __init__(self, ob_space, ac_space, lstm_size=256, **kwargs):
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
    rank = len(ob_space)
    if rank == 3:  # pixel input
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
    elif rank == 1:  # plain features
        # x = tf.nn.elu(linear(x, 256, "l1", normalized_columns_initializer(0.01)))
        pass
    else:
        raise TypeError("observation space must have rank 1 or 3, got %d" % rank)
    # introduce a "fake" batch dimension of 1 after flatten so that we can do LSTM over time dim
    x = tf.expand_dims(flatten(x), [0])

    size = lstm_size
    lnlstm = rnn.LayerNormBasicLSTMCell(size)
    self.state_size = lnlstm.state_size
    step_size = tf.shape(self.x)[:1]

    c_init = np.zeros((1, lnlstm.state_size.c), np.float32)
    h_init = np.zeros((1, lnlstm.state_size.h), np.float32)
    self.state_init = [c_init, h_init]
    c_in = tf.placeholder(tf.float32, [1, lnlstm.state_size.c])
    h_in = tf.placeholder(tf.float32, [1, lnlstm.state_size.h])
    self.state_in = [c_in, h_in]

    state_in = rnn.LSTMStateTuple(c_in, h_in)
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lnlstm, x, initial_state=state_in, sequence_length=step_size,
        time_major=False)
    lstm_c, lstm_h = lstm_state
    x = tf.reshape(lstm_outputs, [-1, size])

    self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
    self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
    self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    self.sample = categorical_sample(self.logits, ac_space)[0, :]
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
def state_size(self):
    return LSTMStateTuple(self.output_size, self.output_size)
def get_vector_representations(sess, model, data, save_dir,
                               batch_size=100,
                               max_batches=None,
                               batches_in_epoch=1000,
                               max_time_diff=float("inf"),
                               extension=".cell"):
    """Given a trained model, get a vector representation for the traces in each batch.

    @param sess: a TensorFlow session
    @param model: the seq2seq model
    @param data: the data, either in batch-major form (not padded) or a list of files
                 (depending on `in_memory`)
    """
    batches = helpers.get_batches(data, batch_size=batch_size)
    batches_in_data = len(data) // batch_size
    if max_batches is None or batches_in_data < max_batches:
        max_batches = batches_in_data - 1
    try:
        for batch in range(max_batches):
            print("Batch {}/{}".format(batch, max_batches))
            fd, paths, _ = model.next_batch(batches, False, max_time_diff)
            l = sess.run(model.encoder_final_state, fd)
            # The encoder state is a tuple for LSTMs, so concatenate c and h.
            if isinstance(l, LSTMStateTuple):
                l = np.concatenate((l.c, l.h), axis=1)
            file_names = [helpers.extract_filename_from_path(path, extension) for path in paths]
            for file_name, features in zip(file_names, list(l)):
                helpers.write_to_file(features, save_dir, file_name, new_extension=".cellf")
    except KeyboardInterrupt:
        stdout.write('Interrupted')
        exit(0)
def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with tf.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = tf.split(axis=1, num_or_size_splits=2, value=state)
        batch_size = tf.shape(inputs)[0]
        inputs = tf.reshape(inputs, [batch_size, self.height, self.width, 1])
        c = tf.reshape(c, [batch_size, self.height, self.width, self.num_features])
        h = tf.reshape(h, [batch_size, self.height, self.width, self.num_features])
        concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=concat)
        new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
                 self._activation(j))
        new_h = self._activation(new_c) * tf.nn.sigmoid(o)
        new_h = tf.reshape(new_h, [batch_size, self._num_units])
        new_c = tf.reshape(new_c, [batch_size, self._num_units])
        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = tf.concat(axis=1, values=[new_c, new_h])
        return new_h, new_state
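A hypothetical wiring sketch for a convolutional cell like the one above: the per-step inputs stay flat ([batch, time, height * width]) and the cell reshapes them internally. The `ConvLSTMCell` name, its constructor signature, and the sizes are assumptions.

# Hypothetical usage; ConvLSTMCell and its constructor arguments are assumptions.
cell = ConvLSTMCell(height=16, width=16, num_features=8, filter_size=[3, 3])
frames = tf.placeholder(tf.float32, [None, None, 16 * 16])   # [batch, time, h*w]
initial = cell.zero_state(tf.shape(frames)[0], tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, frames, initial_state=initial)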
def __init__(self, cell, zoneout_prob, is_training=True):
    if not isinstance(cell, RNNCell):
        raise TypeError("The parameter cell is not an RNNCell.")
    # BasicLSTMCell states are rebuilt as LSTMStateTuple; everything else as a plain tuple.
    if isinstance(cell, BasicLSTMCell):
        self._tuple = lambda x: LSTMStateTuple(*x)
    else:
        self._tuple = lambda x: tuple(x)
    if (isinstance(zoneout_prob, float) and
            not (zoneout_prob >= 0.0 and zoneout_prob <= 1.0)):
        raise ValueError("Parameter zoneout_prob must be between 0 and 1: %f"
                         % zoneout_prob)
    self._cell = cell
    self._zoneout_prob = zoneout_prob
    self.is_training = is_training
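The wrapper's `__call__` is not part of this listing. A minimal sketch of the usual zoneout update, assuming a single zoneout probability shared by all state parts and reusing the `self._tuple` helper defined above; the exact rule used by the original project may differ.

def __call__(self, inputs, state, scope=None):
    output, new_state = self._cell(inputs, state, scope)
    prob = self._zoneout_prob
    if self.is_training:
        # Randomly carry over ("zone out") units of the previous state; the
        # (1 - prob) factor undoes dropout's 1/keep_prob scaling.
        zoned = [(1.0 - prob) * tf.nn.dropout(new_s - old_s, 1.0 - prob) + old_s
                 for new_s, old_s in zip(new_state, state)]
    else:
        # At test time, use the expected value of the stochastic update.
        zoned = [prob * old_s + (1.0 - prob) * new_s
                 for new_s, old_s in zip(new_state, state)]
    return output, self._tuple(zoned)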
Source file: dynamic_seq2seq_model.py (project: seq2seq_chatterbot, author: StephenLee2016)
def decoder_hidden_units(self):
    # @TODO: is this correct for LSTMStateTuple?
    return self.decoder_cell.output_size
Source file: dynamic_seq2seq_model.py (project: seq2seq_chatterbot, author: StephenLee2016)
def _init_bidirectional_encoder(self):
    """Bidirectional LSTM encoder."""
    with tf.variable_scope("BidirectionalEncoder") as scope:
        ((encoder_fw_outputs,
          encoder_bw_outputs),
         (encoder_fw_state,
          encoder_bw_state)) = (
            tf.nn.bidirectional_dynamic_rnn(cell_fw=self.encoder_cell,
                                            cell_bw=self.encoder_cell,
                                            inputs=self.encoder_inputs_embedded,
                                            sequence_length=self.encoder_inputs_length,
                                            time_major=self.time_major,
                                            dtype=tf.float32)
        )
        self.encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
        if isinstance(encoder_fw_state, LSTMStateTuple):
            encoder_state_c = tf.concat(
                (encoder_fw_state.c, encoder_bw_state.c), 1, name='bidirectional_concat_c')
            encoder_state_h = tf.concat(
                (encoder_fw_state.h, encoder_bw_state.h), 1, name='bidirectional_concat_h')
            self.encoder_state = LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(encoder_fw_state, tf.Tensor):
            self.encoder_state = tf.concat((encoder_fw_state, encoder_bw_state), 1,
                                           name='bidirectional_concat')
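Because the forward and backward final states are concatenated along the feature axis, a decoder initialized from `self.encoder_state` needs twice the width of each encoder direction. A hypothetical sizing sketch consistent with the snippet above; `encoder_hidden_units` is an assumed name.

# Hypothetical cell sizing; not part of the original model definition.
encoder_hidden_units = 128
self.encoder_cell = rnn.LSTMCell(encoder_hidden_units)
# The concatenated bidirectional state feeds a decoder of twice the width.
self.decoder_cell = rnn.LSTMCell(2 * encoder_hidden_units)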
def initial_states_tuple(self):
    """
    Create the initial state tensors for the individual RNN cells.

    If no initial state vector was passed to this RNN, all initial states are set to zero.
    Otherwise, the initial state vector is split into a possibly nested tuple of tensors
    according to the RNN architecture. The return value of this function is structured in
    such a way that it can be passed to the `initial_state` parameter of the RNN functions
    in `tf.contrib.rnn`.

    Returns
    -------
    tuple of tf.Tensor
        A possibly nested tuple of initial state tensors for the RNN cells
    """
    if self.initial_state is None:
        initial_states = tf.zeros(shape=[self.batch_size, self.state_size], dtype=tf.float32)
    else:
        initial_states = self.initial_state
    initial_states = tuple(tf.split(initial_states, self.num_layers, axis=1))
    if self.bidirectional:
        initial_states = tuple([tf.split(x, 2, axis=1) for x in initial_states])
        initial_states_fw, initial_states_bw = zip(*initial_states)
        if self.cell_type == CellType.LSTM:
            initial_states_fw = tuple([LSTMStateTuple(*tf.split(lstm_state, 2, axis=1))
                                       for lstm_state in initial_states_fw])
            initial_states_bw = tuple([LSTMStateTuple(*tf.split(lstm_state, 2, axis=1))
                                       for lstm_state in initial_states_bw])
        initial_states = (initial_states_fw, initial_states_bw)
    else:
        if self.cell_type == CellType.LSTM:
            initial_states = tuple([LSTMStateTuple(*tf.split(lstm_state, 2, axis=1))
                                    for lstm_state in initial_states])
    return initial_states
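To make the splitting order concrete, a hedged illustration of the flat layout this function expects, with hypothetical sizes: layers are laid out one after another, a bidirectional layer chunk is [forward | backward], and each directional LSTM chunk is [c | h].

# Hypothetical sizes: 2 layers, bidirectional LSTM, 64 units per direction.
# Flat layout per example:
#   [L1_fw_c | L1_fw_h | L1_bw_c | L1_bw_h | L2_fw_c | L2_fw_h | L2_bw_c | L2_bw_h]
num_layers, units = 2, 64
flat_state = np.zeros([1, num_layers * 2 * 2 * units], np.float32)
l1_fw_c = flat_state[:, :units]   # first slice produced by the splits above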
def __init__(self, ob_space, ac_space):
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
    for i in range(4):
        x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
    # introduce a "fake" batch dimension of 1 after flatten so that we can do LSTM over time dim
    x = tf.expand_dims(flatten(x), [0])

    size = 256
    if use_tf100_api:
        lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
    else:
        lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
    self.state_size = lstm.state_size
    step_size = tf.shape(self.x)[:1]

    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    self.state_init = [c_init, h_init]
    c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
    h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
    self.state_in = [c_in, h_in]

    if use_tf100_api:
        state_in = rnn.LSTMStateTuple(c_in, h_in)
    else:
        state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm, x, initial_state=state_in, sequence_length=step_size,
        time_major=False)
    lstm_c, lstm_h = lstm_state
    x = tf.reshape(lstm_outputs, [-1, size])

    self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
    self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
    self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    self.sample = categorical_sample(self.logits, ac_space)[0, :]
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
def state_size(self):
    return LSTMStateTuple(self.num_units, self.num_units)
def convert_to_state(self, variables):
    if len(variables) != 2:
        raise ValueError()
    return LSTMStateTuple(variables[0], variables[1])
def _init(self, inputs, num_outputs, options):
    use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                     distutils.version.LooseVersion("1.0.0"))
    self.x = x = inputs
    for i in range(4):
        x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
    # Introduce a "fake" batch dimension of 1 after flatten so that we can
    # do LSTM over the time dim.
    x = tf.expand_dims(flatten(x), [0])

    size = 256
    if use_tf100_api:
        lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
    else:
        lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
    step_size = tf.shape(self.x)[:1]

    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    self.state_init = [c_init, h_init]
    c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
    h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
    self.state_in = [c_in, h_in]

    if use_tf100_api:
        state_in = rnn.LSTMStateTuple(c_in, h_in)
    else:
        state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
    lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm, x,
                                             initial_state=state_in,
                                             sequence_length=step_size,
                                             time_major=False)
    lstm_c, lstm_h = lstm_state
    x = tf.reshape(lstm_out, [-1, size])

    logits = linear(x, num_outputs, "action", normc_initializer(0.01))
    self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    return logits, x