network.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:tensorflow-rl 作者: steveKapturowski 项目源码 文件源码
def _build_encoder(self):
        """Build the encoder sub-graph selected by ``self.arch``.

        Everything is created inside ``tf.variable_scope(self.name)``.
        Supported values of ``self.arch``:

        * ``'FC'`` -- flatten ``self.input_ph`` and apply one fully-connected
          layer per entry of ``self.fc_layer_sizes``.
        * ``'ATARI-TRPO'`` / ``'NIPS'`` -- two conv layers followed by one
          fully-connected layer.
        * ``'NATURE'`` -- three conv layers followed by one fully-connected
          layer.

        The conv branches store each layer's ``(w, b, o)`` triple on ``self``
        as ``w1/b1/o1``, ``w2/b2/o2``, ...  When ``self.use_recurrent`` is
        set, the encoder output is additionally fed through an LSTM and
        ``self.lstm_cell``, ``self.ox_reshaped``, ``self.lstm_outputs``,
        ``self.lstm_state`` and ``self.lstm_trainable_variables`` are set.

        Returns:
            The final encoder tensor, also stored as ``self.ox``.

        Raises:
            ValueError: if ``self.arch`` is not one of the supported names.
        """
        with tf.variable_scope(self.name):
            if self.arch == 'FC':
                layer_i = layers.flatten(self.input_ph)
                for i, layer_size in enumerate(self.fc_layer_sizes):
                    # layers.fc returns a 3-tuple (unpacked as w, b, o in the
                    # other branches); only the last element is chained here.
                    layer_i = layers.fc('fc{}'.format(i+1), layer_i, layer_size, activation=self.activation)[-1]
                self.ox = layer_i
            elif self.arch == 'ATARI-TRPO':
                self.w1, self.b1, self.o1 = layers.conv2d('conv1', self.input_ph, 16, 4, self.input_channels, 2, activation=self.activation)
                self.w2, self.b2, self.o2 = layers.conv2d('conv2', self.o1, 16, 4, 16, 2, activation=self.activation)
                self.w3, self.b3, self.o3 = layers.fc('fc3', layers.flatten(self.o2), 20, activation=self.activation)
                self.ox = self.o3
            elif self.arch == 'NIPS':
                self.w1, self.b1, self.o1 = layers.conv2d('conv1', self.input_ph, 16, 8, self.input_channels, 4, activation=self.activation)
                self.w2, self.b2, self.o2 = layers.conv2d('conv2', self.o1, 32, 4, 16, 2, activation=self.activation)
                self.w3, self.b3, self.o3 = layers.fc('fc3', layers.flatten(self.o2), 256, activation=self.activation)
                self.ox = self.o3
            elif self.arch == 'NATURE':
                self.w1, self.b1, self.o1 = layers.conv2d('conv1', self.input_ph, 32, 8, self.input_channels, 4, activation=self.activation)
                self.w2, self.b2, self.o2 = layers.conv2d('conv2', self.o1, 64, 4, 32, 2, activation=self.activation)
                self.w3, self.b3, self.o3 = layers.conv2d('conv3', self.o2, 64, 3, 64, 1, activation=self.activation)
                self.w4, self.b4, self.o4 = layers.fc('fc4', layers.flatten(self.o3), 512, activation=self.activation)
                self.ox = self.o4
            else:
                # ValueError is a subclass of Exception, so existing
                # `except Exception` handlers keep working.
                raise ValueError('Invalid architecture `{}`'.format(self.arch))

            if self.use_recurrent:
                with tf.variable_scope('lstm_layer') as vs:
                    self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(
                        self.hidden_state_size, state_is_tuple=True, forget_bias=1.0)

                    # step_size is passed below as sequence_length, so its
                    # length is used as the number of sequences in the batch.
                    batch_size = tf.shape(self.step_size)[0]
                    feature_dim = self.ox.get_shape().as_list()[-1]
                    # Regroup the encoder output as [batch, time, features]
                    # (time_major=False below).
                    self.ox_reshaped = tf.reshape(self.ox,
                        [batch_size, -1, feature_dim])
                    # initial_lstm_state is split in two along axis 1 to form
                    # the LSTMStateTuple expected by the cell.
                    state_tuple = tf.contrib.rnn.LSTMStateTuple(
                        *tf.split(self.initial_lstm_state, 2, 1))

                    self.lstm_outputs, self.lstm_state = tf.nn.dynamic_rnn(
                        self.lstm_cell,
                        self.ox_reshaped,
                        initial_state=state_tuple,
                        sequence_length=self.step_size,
                        time_major=False)

                    # Re-pack the state tuple into a single tensor, mirroring
                    # the layout of initial_lstm_state.
                    self.lstm_state = tf.concat(self.lstm_state, 1)
                    self.ox = tf.reshape(self.lstm_outputs, [-1,self.hidden_state_size], name='reshaped_lstm_outputs')

                    # Get all LSTM trainable params. Match on the scope name
                    # plus '/' so a sibling scope that merely shares the
                    # prefix (e.g. '<scope>_2') is not picked up by accident.
                    scope_prefix = vs.name + '/'
                    self.lstm_trainable_variables = [
                        v for v in tf.trainable_variables()
                        if v.name.startswith(scope_prefix)]

            return self.ox
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号