def _build_encoder(self):
    """Build the feature-extraction (encoder) part of the network graph.

    Selects one of several fixed architectures based on ``self.arch``:
      - 'FC':         stack of fully-connected layers (sizes from
                      ``self.fc_layer_sizes``), input flattened first.
      - 'ATARI-TRPO': 2 conv layers (16 filters each) + 20-unit FC.
      - 'NIPS':       2 conv layers (16, 32 filters) + 256-unit FC
                      (Mnih et al. 2013 DQN workshop paper layout).
      - 'NATURE':     3 conv layers (32, 64, 64) + 512-unit FC
                      (Mnih et al. 2015 Nature DQN layout).

    If ``self.use_recurrent`` is set, an LSTM layer is appended on top of
    the encoder output.

    Side effects:
        Creates TF variables under ``tf.variable_scope(self.name)`` and
        stores per-layer weights/biases/outputs on ``self`` (``self.w1``,
        ``self.b1``, ``self.o1``, ...) plus, when recurrent,
        ``self.lstm_cell``, ``self.lstm_outputs``, ``self.lstm_state`` and
        ``self.lstm_trainable_variables``.

    Returns:
        The encoder output tensor ``self.ox`` (2-D: [batch, features]).

    Raises:
        ValueError: if ``self.arch`` is not one of the known architectures.
    """
    with tf.variable_scope(self.name):
        if self.arch == 'FC':
            # Flatten the input once, then thread it through each FC layer.
            # NOTE(review): layers.fc appears to return a tuple whose last
            # element is the layer output (see [-1]) — confirm against the
            # project's layers module.
            layer_i = layers.flatten(self.input_ph)
            for i, layer_size in enumerate(self.fc_layer_sizes):
                layer_i = layers.fc('fc{}'.format(i+1), layer_i, layer_size, activation=self.activation)[-1]
            self.ox = layer_i
        elif self.arch == 'ATARI-TRPO':
            # conv2d args presumably: (name, input, num_filters, kernel_size,
            # input_channels, stride) — TODO confirm against layers.conv2d.
            self.w1, self.b1, self.o1 = layers.conv2d('conv1', self.input_ph, 16, 4, self.input_channels, 2, activation=self.activation)
            self.w2, self.b2, self.o2 = layers.conv2d('conv2', self.o1, 16, 4, 16, 2, activation=self.activation)
            self.w3, self.b3, self.o3 = layers.fc('fc3', layers.flatten(self.o2), 20, activation=self.activation)
            self.ox = self.o3
        elif self.arch == 'NIPS':
            self.w1, self.b1, self.o1 = layers.conv2d('conv1', self.input_ph, 16, 8, self.input_channels, 4, activation=self.activation)
            self.w2, self.b2, self.o2 = layers.conv2d('conv2', self.o1, 32, 4, 16, 2, activation=self.activation)
            self.w3, self.b3, self.o3 = layers.fc('fc3', layers.flatten(self.o2), 256, activation=self.activation)
            self.ox = self.o3
        elif self.arch == 'NATURE':
            self.w1, self.b1, self.o1 = layers.conv2d('conv1', self.input_ph, 32, 8, self.input_channels, 4, activation=self.activation)
            self.w2, self.b2, self.o2 = layers.conv2d('conv2', self.o1, 64, 4, 32, 2, activation=self.activation)
            self.w3, self.b3, self.o3 = layers.conv2d('conv3', self.o2, 64, 3, 64, 1, activation=self.activation)
            self.w4, self.b4, self.o4 = layers.fc('fc4', layers.flatten(self.o3), 512, activation=self.activation)
            self.ox = self.o4
        else:
            # ValueError is the conventional exception for a bad config
            # value; still a subclass of Exception, so existing callers
            # catching Exception are unaffected.
            raise ValueError('Invalid architecture `{}`'.format(self.arch))

        if self.use_recurrent:
            with tf.variable_scope('lstm_layer') as vs:
                self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(
                    self.hidden_state_size, state_is_tuple=True, forget_bias=1.0)
                # NOTE(review): batch size is inferred from step_size's
                # leading dim — assumes step_size is shaped [batch]; confirm.
                batch_size = tf.shape(self.step_size)[0]
                # Reshape flat encoder output [batch*time, features] into
                # batch-major [batch, time, features] for dynamic_rnn.
                self.ox_reshaped = tf.reshape(self.ox,
                    [batch_size, -1, self.ox.get_shape().as_list()[-1]])
                # initial_lstm_state packs (c, h) concatenated along axis 1;
                # split it back into the tuple form the cell expects.
                state_tuple = tf.contrib.rnn.LSTMStateTuple(
                    *tf.split(self.initial_lstm_state, 2, 1))
                self.lstm_outputs, self.lstm_state = tf.nn.dynamic_rnn(
                    self.lstm_cell,
                    self.ox_reshaped,
                    initial_state=state_tuple,
                    sequence_length=self.step_size,
                    time_major=False)
                # Re-pack the (c, h) state tuple into a single tensor so it
                # can be fed back in as initial_lstm_state on the next step.
                self.lstm_state = tf.concat(self.lstm_state, 1)
                # Flatten back to [batch*time, hidden] to match the
                # non-recurrent encoder output shape.
                self.ox = tf.reshape(self.lstm_outputs, [-1,self.hidden_state_size], name='reshaped_lstm_outputs')

                # Get all LSTM trainable params
                self.lstm_trainable_variables = [v for v in
                    tf.trainable_variables() if v.name.startswith(vs.name)]

    return self.ox