def __init__(
        self,
        input_size,
        nb_channels=3,
        conditional=False,
        latent_dim=10,
        nb_pixelcnn_layers=13,
        nb_filters=128,
        filter_size_1st=(7,7),
        filter_size=(3,3),
        optimizer='adadelta',
        es_patience=100,
        save_root='/tmp/pixelcnn',
        save_best_only=False,
        **kwargs):
    '''Store the PixelCNN hyper-parameters and create the training callbacks.

    Args:
        input_size ((int,int))      : (height, width) pixels of input images
        nb_channels (int)           : Number of channels for input images. Must be 1 (grayscale) or 3 (color);
                                      it selects the loss: 'binary_crossentropy' for 1, 'categorical_crossentropy' for 3. (default:3)
        conditional (bool)          : if True, use latent vector to model the conditional distribution p(x|h) (default:False)
        latent_dim (int)            : (if conditional==True,) Dimensions for latent vector. (default:10)
        nb_pixelcnn_layers (int)    : Number of layers (except last two ReLu layers). (default:13)
        nb_filters (int)            : Number of filters (feature maps) for each layer. (default:128)
        filter_size_1st ((int, int)): Kernel size for the first layer. (default: (7,7))
        filter_size ((int, int))    : Kernel size for the subsequent layers. (default: (3,3))
        optimizer (str)             : Optimizer name, stored as-is on the instance (default: 'adadelta')
        es_patience (int)           : Number of epochs with no improvement in 'val_loss' after which training will be stopped (EarlyStopping) (default:100)
        save_root (str)             : Root directory for the weight checkpoints and the TensorBoard log
                                      (sub-directory 'pixelcnn-tensorboard'). The original notes also mention parameter.txt;
                                      that file is not written by this method. (default: '/tmp/pixelcnn')
        save_best_only (bool)       : Forwarded to ModelCheckpoint; if True, the latest best model will not be overwritten (default: False)
        **kwargs                    : Accepted for call compatibility; not read anywhere in this method.

    Raises:
        ValueError: if nb_channels is neither 1 nor 3.
    '''
    # Validate first, before any global Keras state is touched: previously an
    # unsupported channel count silently left `self.loss` undefined and only
    # failed later with an unrelated AttributeError.
    if nb_channels == 1:
        loss = 'binary_crossentropy'
    elif nb_channels == 3:
        loss = 'categorical_crossentropy'
    else:
        raise ValueError(
            'nb_channels must be 1 (grayscale) or 3 (color), got %r' % (nb_channels,))

    # NOTE: this changes the image dimension ordering for the whole Keras backend,
    # not just for this instance.
    K.set_image_dim_ordering('tf')

    # Network hyper-parameters.
    self.input_size = input_size
    self.conditional = conditional
    self.latent_dim = latent_dim
    self.nb_pixelcnn_layers = nb_pixelcnn_layers
    self.nb_filters = nb_filters
    self.filter_size_1st = filter_size_1st
    self.filter_size = filter_size
    self.nb_channels = nb_channels
    self.loss = loss

    # Training settings.
    self.optimizer = optimizer
    self.es_patience = es_patience
    self.save_best_only = save_best_only

    # Output locations under save_root. The checkpoint name is a template:
    # the {epoch} / {val_loss} placeholders are filled in by ModelCheckpoint.
    tensorboard_dir = os.path.join(save_root, 'pixelcnn-tensorboard')
    checkpoint_path = os.path.join(save_root, 'pixelcnn-weights.{epoch:02d}-{val_loss:.4f}.hdf5')

    # Training callbacks.
    self.tensorboard = TensorBoard(log_dir=tensorboard_dir)
    ### "save_weights_only=False" causes error when exporting model architecture. (json or yaml)
    self.checkpointer = ModelCheckpoint(filepath=checkpoint_path, verbose=1, save_weights_only=True, save_best_only=save_best_only)
    self.earlystopping = EarlyStopping(monitor='val_loss', patience=es_patience, verbose=0, mode='auto')
评论列表
文章目录