def __init__(self,
             X_train_file='',
             Y_train_file='',
             batch_size=1,
             image_size=256,
             use_lsgan=True,
             norm='instance',
             lambda1=10.0,
             lambda2=10.0,
             learning_rate=2e-4,
             beta1=0.5,
             ngf=64
            ):
    """Build the CycleGAN sub-networks and input placeholders.

    Creates two generators (``G`` and ``F``), two discriminators
    (``D_Y`` and ``D_X``), a boolean ``is_training`` placeholder shared by
    all four networks, and two float32 image placeholders
    ``fake_x`` / ``fake_y``.

    Args:
      X_train_file: string, X tfrecords file for training
      Y_train_file: string, Y tfrecords file for training
      batch_size: integer, batch size (first dimension of the
        ``fake_x`` / ``fake_y`` placeholders)
      image_size: integer, image size; used as both height and width of
        the ``fake_x`` / ``fake_y`` placeholders and passed to both
        generators
      use_lsgan: boolean; when True the discriminators are built with
        ``use_sigmoid=False``, otherwise with ``use_sigmoid=True``
      norm: 'instance' or 'batch'; passed to all four networks
      lambda1: float, weight for forward cycle loss (X->Y->X)
      lambda2: float, weight for backward cycle loss (Y->X->Y)
      learning_rate: float, initial learning rate for Adam
      beta1: float, momentum term of Adam
      ngf: integer, number of generator filters in the first conv layer;
        applied to both generators ``G`` and ``F``
    """
    # Hyper-parameters and input files are only stored here; they are
    # presumably consumed by other methods of this class (not visible).
    self.lambda1 = lambda1
    self.lambda2 = lambda2
    self.use_lsgan = use_lsgan
    self.batch_size = batch_size
    self.image_size = image_size
    self.learning_rate = learning_rate
    self.beta1 = beta1
    self.X_train_file = X_train_file
    self.Y_train_file = Y_train_file

    # The least-squares (LSGAN) objective works on raw discriminator
    # scores, so a sigmoid output is only requested when LSGAN is off.
    use_sigmoid = not use_lsgan

    # Defaults to True so training runs need not feed it; feed False to
    # switch the networks into inference mode.
    self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training')

    # NOTE(review): G is presumably the X->Y mapping judged by D_Y, and F the
    # Y->X mapping judged by D_X (inferred from the names) — confirm in model().
    # Creation order is kept unchanged, since graph/variable creation can be
    # order-sensitive.
    self.G = Generator('G', self.is_training, ngf=ngf, norm=norm, image_size=image_size)
    self.D_Y = Discriminator('D_Y',
                             self.is_training, norm=norm, use_sigmoid=use_sigmoid)
    # F must receive the same ngf as G; previously it silently fell back to
    # Generator's default, so a non-default ngf produced mismatched generators.
    self.F = Generator('F', self.is_training, ngf=ngf, norm=norm, image_size=image_size)
    self.D_X = Discriminator('D_X',
                             self.is_training, norm=norm, use_sigmoid=use_sigmoid)

    # Externally fed generated images (NHWC, 3 channels). Presumably used to
    # feed previously generated samples to the discriminators — verify
    # against the training loop.
    self.fake_x = tf.placeholder(tf.float32,
                                 shape=[batch_size, image_size, image_size, 3])
    self.fake_y = tf.placeholder(tf.float32,
                                 shape=[batch_size, image_size, image_size, 3])
评论列表
文章目录