Gan.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ICGan-tensorflow 作者: zhangqianhui 项目源码 文件源码
def train_ez(self):
        """Train the z-encoder by minimizing ``self.loss_z``.

        Builds an Adam optimizer restricted to ``self.enz_vars``, restores the
        checkpoint at ``self.model_path`` through ``self.saver``, and then
        loops over the training set in mini-batches, feeding the batch labels
        and freshly sampled Gaussian noise.  The loss is printed every 10
        steps; the encoder is checkpointed to ``self.encode_z_model`` through
        ``self.saver_z`` every 50 steps and once more after the last epoch.

        Returns:
            None.  The path of the final checkpoint is printed.
        """
        # Create the optimizer before the initializer op so that the
        # optimizer's own variables are covered by the initializer as well.
        opti_EZ = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.5).minimize(
            self.loss_z, var_list=self.enz_vars)
        init = tf.global_variables_initializer()

        # Floor division keeps the batch count an integer on both Python 2 and
        # Python 3; a trailing partial batch is skipped.
        num_batches = len(self.ds_train) // self.batch_size

        # Offset handed to MnistData.getNextBatch; the original code pinned it
        # to 0 right after drawing a random value, so only the 0 ever applied.
        rand = 0

        with tf.Session() as sess:

            sess.run(init)
            # Only the graph is recorded; no scalar summaries are written.
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            # Overwrite the fresh initialization with the saved weights.
            self.saver.restore(sess, self.model_path)

            step = 0

            # NOTE(review): the bound is inclusive, so this runs
            # max_epoch + 1 passes over the data -- confirm that is intended.
            for e in range(self.max_epoch + 1):

                for batch_num in range(num_batches):

                    step += 1

                    # Only the labels of the batch are needed here; the first
                    # element of the returned pair is discarded.
                    _, label_y = MnistData.getNextBatch(
                        self.ds_train, self.label_y, rand, batch_num, self.batch_size)
                    batch_z = np.random.normal(
                        0, 1, size=[self.batch_size, self.sample_size])
                    feed = {self.y: label_y, self.z: batch_z}

                    # optimization E
                    sess.run(opti_EZ, feed_dict=feed)

                    if step % 10 == 0:
                        # Loss is evaluated after the update, on the same batch.
                        ez_loss = sess.run(self.loss_z, feed_dict=feed)
                        print("EPOCH %d step %d EZ loss %.7f" % (e, step, ez_loss))

                    if step % 50 == 0:
                        self.saver_z.save(sess, self.encode_z_model)

            save_path = self.saver_z.save(sess, self.encode_z_model)
            print("Model saved in file: %s" % save_path)

            # Flush the pending graph event and release the event file.
            summary_writer.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号