gait_nn.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:gait-recognition 作者: marian-margeta 项目源码 文件源码
def get_network(self, input_tensor, is_training, reuse = False):
    """Build the GaitNN graph: conv/residual down-sampling, FC, RNN, optional logits.

    Args:
        input_tensor: Input feature maps. Presumably (num_frames, H, W, C) for a
            single sequence, because the frame axis becomes the time axis of a
            batch-size-1 RNN below — TODO confirm against the caller.
        is_training: Python bool. Forwarded to ``self.get_arg_scope``. When True,
            dropout and the per-person classification layer are also built.
        reuse: Forwarded to the top-level ``GaitNN`` variable scope.

    Returns:
        Tuple ``(net, gait_signature, state)``:
            net: logits of shape (1, self.num_of_persons) when ``is_training``
                is True. Otherwise, the per-frame RNN outputs reshaped to
                (-1, self.FEATURES).
            gait_signature: mean of the RNN outputs over the frame axis,
                shape (self.FEATURES,).
            state: final RNN state as returned by ``tf.nn.dynamic_rnn``.
    """
    net = input_tensor

    with tf.variable_scope('GaitNN', reuse = reuse):
        with slim.arg_scope(self.get_arg_scope(is_training)):
            with tf.variable_scope('DownSampling'):
                # NOTE: slim.repeat returns the output of its last repetition, so
                # the result must be assigned back to `net`. An earlier version
                # discarded it, which left these residual blocks disconnected
                # from the loss. Checkpoints from that version use the same
                # variable names, but these blocks presumably never trained.
                # Retrain before reusing such weights.
                with tf.variable_scope('17x17'):
                    net = layers.convolution2d(net, num_outputs = 256, kernel_size = 1)
                    net = slim.repeat(net, 3, self.residual_block, ch = 256, ch_inner = 64)

                with tf.variable_scope('8x8'):
                    net = self.residual_block(net, ch = 512, ch_inner = 64, stride = 2)
                    net = slim.repeat(net, 2, self.residual_block, ch = 512, ch_inner = 128)

                with tf.variable_scope('4x4'):
                    net = self.residual_block(net, ch = 512, ch_inner = 128, stride = 2)
                    net = slim.repeat(net, 1, self.residual_block, ch = 512, ch_inner = 256)

                    net = layers.convolution2d(net, num_outputs = 256, kernel_size = 1)
                    net = layers.convolution2d(net, num_outputs = 256, kernel_size = 3)

            with tf.variable_scope('FullyConnected'):
                # One 512-d feature vector per frame (linear, no normalizer).
                net = layers.flatten(net)
                net = layers.fully_connected(net, 512, activation_fn = None, normalizer_fn = None)

            with tf.variable_scope('Recurrent', initializer = tf.contrib.layers.xavier_initializer()):
                cell_type = {
                    'GRU': tf.nn.rnn_cell.GRUCell,
                    'LSTM': tf.nn.rnn_cell.LSTMCell
                }
                cell_fn = cell_type[self.recurrent_unit]

                # Build a distinct cell object for every layer. `[cell] * n` puts
                # the *same* object in each layer, which TF >= 1.1 rejects or
                # turns into weight sharing across layers.
                cell = tf.nn.rnn_cell.MultiRNNCell(
                    [cell_fn(self.FEATURES) for _ in range(self.rnn_layers)],
                    state_is_tuple = True)

                # Treat the frame axis as time: a single sequence, batch size 1.
                net = tf.expand_dims(net, 0)
                net, state = tf.nn.dynamic_rnn(cell, net, initial_state = cell.zero_state(1, dtype = tf.float32))
                net = tf.reshape(net, [-1, self.FEATURES])

                # Temporal average pooling over all frames.
                gait_signature = tf.reduce_mean(net, 0)

            if is_training:
                net = tf.expand_dims(gait_signature, 0)
                net = layers.dropout(net, keep_prob = 0.7)

                with tf.variable_scope('Logits'):
                    net = layers.fully_connected(net, self.num_of_persons, activation_fn = None,
                                                 normalizer_fn = None)

            return net, gait_signature, state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号