model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:DeepVideo 作者: AniketBajpai 项目源码 文件源码
def build_model(self, is_train=True):
        """Assemble the model's TensorFlow graph and attach its tensors to ``self``.

        Graph pieces created here (all stored as attributes):
          * ``current_frames`` / ``future_frames``: float32 input placeholders,
            both shaped [batch_size, num_frames, image_height, image_width,
            num_channels].
          * ``is_train``: scalar boolean placeholder with a default value.
          * ``E`` (Encoder) applied to ``current_frames`` gives ``z``.
          * ``Gr`` and ``Gf`` (two Generators sharing ``configs_generator``)
            are both applied to ``z``. Their outputs are
            ``generated_current_frames`` and ``generated_future_frames``.
          * ``D`` (one Discriminator) is applied to the real and generated
            current/future clips. Each call returns a pair, stored as
            ``D_{real,fake}_{current,future}`` and ``D_{real,fake}_{current,future}_logits``.

        Args:
            is_train: Default for the ``is_train`` placeholder (coerced with
                ``bool``).
        """
        # Both input clips share one shape.
        frame_shape = [
            self.batch_size,
            self.num_frames,
            self.image_height,
            self.image_width,
            self.num_channels,
        ]
        self.current_frames = tf.placeholder(tf.float32, shape=frame_shape, name='current_frames')
        self.future_frames = tf.placeholder(tf.float32, shape=frame_shape, name='future_frames')
        # NOTE: a class-label placeholder ([batch_size, num_classes]) used to be
        # defined here and is currently disabled.

        self.is_train = tf.placeholder_with_default(bool(is_train), shape=[], name='is_train')
        # NOTE(review): self.is_train is not passed to E, Gr, Gf or D below.
        # Confirm the sub-networks read it some other way; otherwise feeding it
        # has no effect on the graph built here.

        debug = self.is_debug

        # Encoder: current clip -> latent code.
        self.E = Encoder('Encoder', self.configs_encoder)
        self.z = self.E(self.current_frames, is_debug=debug)

        # Two generators decode the same latent code.
        self.Gr = Generator('Generator_R', self.configs_generator)
        self.Gf = Generator('Generator_F', self.configs_generator)
        self.generated_current_frames = self.Gr(self.z, is_debug=debug)
        self.generated_future_frames = self.Gf(self.z, is_debug=debug)

        # One discriminator scores every real/generated clip.
        # NOTE(review): the same Discriminator instance is called four times.
        # This assumes it reuses its variables after the first call; confirm
        # in Discriminator.
        self.D = Discriminator('Discriminator', self.configs_discriminator)

        def discriminate(frames):
            return self.D(frames, is_debug=debug)

        self.D_real_current, self.D_real_current_logits = discriminate(self.current_frames)
        self.D_fake_current, self.D_fake_current_logits = discriminate(self.generated_current_frames)
        self.D_real_future, self.D_real_future_logits = discriminate(self.future_frames)
        self.D_fake_future, self.D_fake_future_logits = discriminate(self.generated_future_frames)

        print_message('Successfully loaded the model')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号