def build_model(self):
    """Build the training and inference graphs of the VPN model.

    Training graph: the encoder/decoder templates are unrolled for
    ``self.config.truncated_steps`` steps.
    - At step ``i`` the encoder consumes ``self.sequences[:, i]`` together with
      the running LSTM state.
    - The decoder then consumes the encoder output together with
      ``self.sequences[:, i + 1]``.
    - The per-step decoder outputs are stacked into ``self.output``.
    - That output is scored with a mean softmax cross-entropy against depth-256
      one-hot labels built from ``self.sequences[:, 1:]``.
    - ``self.optimizer`` is the RMSProp ``minimize`` op for that loss.

    Inference graph:
    - One encoder step on ``self.inference_prev_frame``, starting from
      ``self.initial_lstm_state``.
    - One decoder step on ``self.inference_current_frame``.
    - Both graphs call the same ``tf.make_template`` objects, so they share one
      set of variables.

    Attributes set on ``self``:
    - ``final_lstm_state``, ``output``, ``loss``, ``optimizer``
    - ``encoder_state``, ``inference_lstm_state``, ``inference_output``
    - ``test_summaries``: one merged summary op per step index.
    - ``summaries``: the merge of the ``'vpn'`` summary collection.
    """
    # LSTMStateTuple is an immutable namedtuple, so one instance can seed both
    # the training unroll and the inference graph.
    initial_lstm_state = tf.contrib.rnn.LSTMStateTuple(self.initial_lstm_state[0], self.initial_lstm_state[1])

    # tf.make_template shares variables between calls, so every unrolled step and
    # the inference graph below use a single set of encoder/decoder weights.
    encoder_network_template = tf.make_template('vpn_encoder', self.encoder_template)
    decoder_network_template = tf.make_template('vpn_decoder', self.decoder_template)

    with tf.name_scope('training_graph'):
        lstm_state = initial_lstm_state
        net_unwrap = []
        for i in range(self.config.truncated_steps):
            # Encode frame i, carrying the LSTM state forward through the unroll.
            encoder_state, lstm_state = encoder_network_template(self.sequences[:, i], lstm_state)
            # Decode conditioned on frame i + 1.
            step_out = decoder_network_template(encoder_state, self.sequences[:, i + 1])
            net_unwrap.append(step_out)
        self.final_lstm_state = lstm_state

        with tf.name_scope('wrap_out'):
            # tf.stack adds a leading step axis. The transpose swaps it with axis 0,
            # so the step index becomes axis 1 of this rank-5 tensor.
            self.output = tf.transpose(tf.stack(net_unwrap), [1, 0, 2, 3, 4])
            for i in range(self.config.truncated_steps):
                # self.output[:, i] is rank 4, so axis 3 is its last axis. The argmax
                # picks the highest-scoring class at each position. expand_dims
                # restores a trailing size-1 axis for the image summary.
                predicted_frame = tf.expand_dims(tf.cast(tf.argmax(self.output[:, i], axis=3), tf.float32), 3)
                Logger.summarize_images(predicted_frame, 'frame_{0}'.format(i), 'vpn', 1)

        with tf.name_scope('loss'):
            # NOTE(review): tf.squeeze with no axis also drops the batch or step axis
            # when either has size 1. That would leave `labels` with a different rank
            # from `self.output`. Presumably the intent is to drop only a trailing
            # channel axis of size 1; confirm, then pass axis=-1.
            labels = tf.one_hot(tf.cast(tf.squeeze(self.sequences[:, 1:]), tf.int32),
                                256,
                                axis=-1,
                                dtype=tf.float32)
            # softmax_cross_entropy_with_logits needs labels and logits of the same
            # shape. The decoder must therefore emit 256 logits on its last axis.
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.output, labels=labels))
            self.optimizer = tf.train.RMSPropOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

    with tf.name_scope('inference_graph'):
        # Inference restarts from the initial LSTM state, not from the state
        # reached at the end of the training unroll.
        self.encoder_state, lstm_state = encoder_network_template(self.inference_prev_frame, initial_lstm_state)
        self.inference_lstm_state = lstm_state
        # NOTE(review): the decoder reads `self.inference_encoder_state`, not the
        # `self.encoder_state` computed just above. Presumably it is a placeholder
        # defined elsewhere so that an encoder output can be fed back in; confirm.
        self.inference_output = decoder_network_template(self.inference_encoder_state,
                                                         self.inference_current_frame)

        with tf.name_scope('test_frames'):
            # The prediction does not depend on the step index, so build it once.
            predicted_frame = tf.expand_dims(tf.cast(tf.argmax(self.inference_output, axis=3), tf.float32), 3)
            self.test_summaries = []
            for i in range(self.config.truncated_steps):
                # The same prediction is summarized once per step index, each time
                # with a distinct tag and key. Presumably Logger adds the summary to
                # the collection named by its third argument; merge_all then yields
                # one merged op per step.
                Logger.summarize_images(predicted_frame, 'test_frame_{0}'.format(i), 'vpn_test_{0}'.format(i), 1)
                self.test_summaries.append(tf.summary.merge_all('vpn_test_{0}'.format(i)))

    self.summaries = tf.summary.merge_all('vpn')
评论列表
文章目录