recognition_model.py source code

python

Project: GestureRecognition    Author: gkchai
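import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in TF 2.x
from tensorflow.contrib.rnn import LSTMCell


def create_base(self, inputs, is_training):
    """Build a stacked-LSTM classifier and return (logits, predicted_classes).

    inputs: float tensor of shape [batch_size, num_steps], one feature per step.
    """

    def single_cell(size):
        # A keep probability of 1.0 disables dropout; wrapping the cell in both
        # modes keeps the cell structure identical for training and inference.
        keep_prob = self._config.keep_prob if is_training else 1.0
        return tf.contrib.rnn.DropoutWrapper(LSTMCell(size), output_keep_prob=keep_prob)

    with tf.name_scope('Model'):
        hidden_sizes = self._config.lstm_params['hidden_sizes']
        cell = tf.contrib.rnn.MultiRNNCell([single_cell(size) for size in hidden_sizes])

        # [batch, num_steps] -> [batch, num_steps, 1] -> list of num_steps tensors of
        # shape [batch, 1]; static_rnn builds a zero initial state from `dtype`.
        input_list = tf.unstack(tf.expand_dims(inputs, axis=2), axis=1)
        outputs, _ = tf.nn.static_rnn(cell, input_list, dtype=tf.float32)

        # Take the last output in the sequence: shape [batch, hidden_sizes[-1]].
        output = outputs[-1]

        with tf.name_scope('final_layer'):
            with tf.name_scope('Wx_plus_b'):
                softmax_w = tf.get_variable(
                    'softmax_w', [hidden_sizes[-1], self._config.num_classes],
                    initializer=tf.contrib.layers.xavier_initializer())
                softmax_b = tf.get_variable(
                    'softmax_b', [self._config.num_classes],
                    initializer=tf.constant_initializer(0.1))
                logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b, name='logits')

        with tf.name_scope('output'):
            # `axis` replaces the deprecated `dimension` argument of tf.argmax.
            predicted_classes = tf.cast(tf.argmax(logits, axis=1), tf.int32, name='y')

    return logits, predicted_classes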
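The shape handling around the RNN in create_base can be easy to misread, so here is a minimal NumPy sketch of the two non-recurrent steps: turning a [batch, num_steps] input into a per-step list (what tf.expand_dims plus tf.unstack do) and the final xw_plus_b + argmax layer. The helper names unstack_time_major and final_layer are hypothetical, and the LSTM itself is not reproduced: a random array stands in for outputs[-1], purely as an assumption to illustrate shapes.

```python
import numpy as np


def unstack_time_major(inputs):
    """Mimic tf.unstack(tf.expand_dims(inputs, axis=2), axis=1).

    inputs: array of shape [batch, num_steps]
    returns: list of num_steps arrays, each of shape [batch, 1]
    """
    expanded = np.expand_dims(inputs, axis=2)  # [batch, num_steps, 1]
    return [expanded[:, t, :] for t in range(expanded.shape[1])]


def final_layer(output, softmax_w, softmax_b):
    """Mimic tf.nn.xw_plus_b followed by an int32 argmax over the class axis."""
    logits = output @ softmax_w + softmax_b  # [batch, num_classes]
    predicted_classes = np.argmax(logits, axis=1).astype(np.int32)  # [batch]
    return logits, predicted_classes


rng = np.random.default_rng(0)
batch, num_steps, hidden, num_classes = 4, 10, 8, 3
steps = unstack_time_major(rng.standard_normal((batch, num_steps)))
last_output = rng.standard_normal((batch, hidden))  # hypothetical stand-in for outputs[-1]
w = rng.standard_normal((hidden, num_classes))
b = np.full(num_classes, 0.1)  # mirrors constant_initializer(0.1)
logits, preds = final_layer(last_output, w, b)
print(len(steps), steps[0].shape, logits.shape, preds.shape)  # 10 (4, 1) (4, 3) (4,)
```

Each list element corresponds to one time step fed to static_rnn, which is why the model keeps only the last element of outputs for classification.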