model.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Video-Pixel-Networks 作者: 3ammor 项目源码 文件源码
def build_model(self):
        """Build the TF1 static graph for a Video Pixel Network.

        Constructs four pieces on the default graph:
          * a truncated-BPTT training graph that unrolls the encoder/decoder
            for ``self.config.truncated_steps`` frames,
          * a softmax cross-entropy loss over 256 pixel-intensity classes and
            an RMSProp training op,
          * a single-step inference graph that reuses the same template
            variables, and
          * image/scalar summaries collected under the 'vpn' and
            'vpn_test_{i}' collections.

        Side effects (attributes set on ``self``): ``final_lstm_state``,
        ``output``, ``loss``, ``optimizer``, ``encoder_state``,
        ``inference_lstm_state``, ``inference_output``, ``test_summaries``,
        ``summaries``.

        # assumes self.sequences is (batch, time, H, W, 1) with integer pixel
        # values in [0, 255] — inferred from the one_hot(..., 256) below;
        # TODO confirm against the input pipeline.
        """
        # Wrap the two placeholder state tensors into the tuple form the
        # contrib LSTM cell expects: (c, h).
        lstm_state = tf.contrib.rnn.LSTMStateTuple(self.initial_lstm_state[0], self.initial_lstm_state[1])
        # make_template ensures the training and inference graphs below share
        # one set of encoder/decoder variables (created on first call).
        encoder_network_template = tf.make_template('vpn_encoder', self.encoder_template)
        decoder_network_template = tf.make_template('vpn_decoder', self.decoder_template)

        with tf.name_scope('training_graph'):
            net_unwrap = []
            # Unroll over time: frame i is encoded, and the decoder predicts
            # frame i+1 conditioned on the encoder state (teacher forcing).
            for i in range(self.config.truncated_steps):
                encoder_state, lstm_state = encoder_network_template(self.sequences[:, i], lstm_state)
                step_out = decoder_network_template(encoder_state, self.sequences[:, i + 1])
                net_unwrap.append(step_out)

            # Exposed so the trainer can carry LSTM state across truncation
            # boundaries (truncated BPTT).
            self.final_lstm_state = lstm_state

        with tf.name_scope('wrap_out'):
            # Stack per-step outputs: (steps, batch, H, W, 256), then move
            # batch first -> (batch, steps, H, W, 256).
            net_unwrap = tf.stack(net_unwrap)
            self.output = tf.transpose(net_unwrap, [1, 0, 2, 3, 4])

            for i in range(self.config.truncated_steps):
                # Log the argmax-decoded predicted frame for each step.
                # NOTE(review): tf.arg_max is a deprecated alias of tf.argmax.
                Logger.summarize_images(tf.expand_dims(tf.cast(tf.arg_max(self.output[:, i], 3), tf.float32), 3),
                                        'frame_{0}'.format(i), 'vpn', 1)

        with tf.name_scope('loss'):
            # Targets are the next frames (sequences shifted by one), one-hot
            # encoded over the 256 possible pixel intensities.
            labels = tf.one_hot(tf.cast(tf.squeeze(self.sequences[:, 1:]), tf.int32),
                                256,
                                axis=-1,
                                dtype=tf.float32)
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.output, labels=labels))
            self.optimizer = tf.train.RMSPropOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope('inference_graph'):
            # Fresh state tuple from the same placeholders; templates reuse
            # the variables created by the training graph above.
            lstm_state = tf.contrib.rnn.LSTMStateTuple(self.initial_lstm_state[0], self.initial_lstm_state[1])
            self.encoder_state, lstm_state = encoder_network_template(self.inference_prev_frame, lstm_state)
            self.inference_lstm_state = lstm_state
            # NOTE(review): this line assigns self.encoder_state but the
            # decoder consumes self.inference_encoder_state — presumably a
            # separately-fed placeholder, otherwise a wiring bug; verify
            # against where inference_encoder_state is defined.
            self.inference_output = decoder_network_template(self.inference_encoder_state, self.inference_current_frame)

        with tf.name_scope('test_frames'):
            self.test_summaries = []
            # One merged summary op per generation step; each is fed/run
            # separately at sampling time.
            # NOTE(review): every iteration summarizes the same
            # self.inference_output tensor — the per-step difference
            # presumably comes from what is fed at session.run time; confirm.
            for i in range(self.config.truncated_steps):
                Logger.summarize_images(tf.expand_dims(tf.cast(tf.arg_max(self.inference_output, 3), tf.float32), 3),
                                        'test_frame_{0}'.format(i), 'vpn_test_{0}'.format(i), 1)
                self.test_summaries.append(tf.summary.merge_all('vpn_test_{0}'.format(i)))

        # All training-time summaries registered under the 'vpn' collection.
        self.summaries = tf.summary.merge_all('vpn')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号