def build_model(self, is_train=True):
    """Build the graph: input placeholders, encoder, two generators, discriminator.

    The encoder ``E`` maps ``current_frames`` to ``z``. Two generators both
    consume ``z``: ``Gr`` produces ``generated_current_frames`` and ``Gf``
    produces ``generated_future_frames``. A single discriminator object ``D``
    is then applied to real and generated current/future frames. Each call
    returns a pair, stored as ``D_<real|fake>_<current|future>`` and the
    matching ``..._logits`` attribute.

    Args:
        is_train: Default value of the scalar ``is_train`` placeholder
            (coerced with ``bool``).
    """

    def frames_placeholder(tensor_name):
        # float32 5-D input: [batch, frames, height, width, channels].
        # The shape list is rebuilt per call so each placeholder gets its own.
        dims = [self.batch_size, self.num_frames, self.image_height,
                self.image_width, self.num_channels]
        return tf.placeholder(name=tensor_name, dtype=tf.float32, shape=dims)

    # -- Inputs ---------------------------------------------------------------
    self.current_frames = frames_placeholder('current_frames')
    self.future_frames = frames_placeholder('future_frames')
    # NOTE: an optional 'label' placeholder (float32, [batch_size, num_classes])
    # is disabled and not created here.
    self.is_train = tf.placeholder_with_default(
        bool(is_train), [], name='is_train')
    # NOTE(review): is_train is not passed to E/Gr/Gf/D below; confirm the
    # sub-networks pick up the training flag some other way.

    debug = self.is_debug

    # -- Encoder --------------------------------------------------------------
    encoder = Encoder('Encoder', self.configs_encoder)
    self.E = encoder
    self.z = encoder(self.current_frames, is_debug=debug)

    # -- Generators -----------------------------------------------------------
    gen_current = Generator('Generator_R', self.configs_generator)
    gen_future = Generator('Generator_F', self.configs_generator)
    self.Gr, self.Gf = gen_current, gen_future
    self.generated_current_frames = gen_current(self.z, is_debug=debug)
    self.generated_future_frames = gen_future(self.z, is_debug=debug)

    # -- Discriminator --------------------------------------------------------
    # One Discriminator object scores all four inputs, in this fixed order.
    # NOTE(review): whether weights are shared across these calls depends on
    # the Discriminator implementation (not shown) -- confirm it reuses them.
    disc = Discriminator('Discriminator', self.configs_discriminator)
    self.D = disc
    self.D_real_current, self.D_real_current_logits = disc(
        self.current_frames, is_debug=debug)
    self.D_fake_current, self.D_fake_current_logits = disc(
        self.generated_current_frames, is_debug=debug)
    self.D_real_future, self.D_real_future_logits = disc(
        self.future_frames, is_debug=debug)
    self.D_fake_future, self.D_fake_future_logits = disc(
        self.generated_future_frames, is_debug=debug)

    print_message('Successfully loaded the model')
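# Stray web-page navigation text from the original source page
# ("Comments list" / "Table of contents"); not part of the code.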