def main():
    """Prepare the data, build the GAN graph, and run training.

    Steps, in order:
      1. Create the checkpoints directory and download/load the train and
         test sets through ``utils``.
      2. Reset the default TensorFlow graph and create the ``input_real``
         placeholder that feeds real images to the model.
      3. Build the ``GAN`` model and wrap the data in a ``Dataset``.
      4. Call ``train`` with the configured epochs and batch size.

    All hyper-parameters are hard-coded below. The function returns
    nothing; the three values returned by ``train`` are discarded.
    """
    # Preparations: output directory and raw data.
    create_checkpoints_dir()
    utils.download_train_and_test_data()
    trainset, testset = utils.load_data_sets()

    # Create the real input for the GAN model (its discriminator) and the
    # GAN model itself.
    # NOTE(review): (32, 32, 3) presumably means 32x32 RGB images -- confirm
    # against the data returned by utils.load_data_sets().
    real_size = (32, 32, 3)
    z_size = 100
    learning_rate = 0.0003

    # Clear the default graph first so the placeholder and model below are
    # added to an empty graph (matters when main() runs more than once in
    # the same process).
    tf.reset_default_graph()
    # Leading None leaves the batch dimension unspecified.
    input_real = tf.placeholder(tf.float32, (None, *real_size), name='input_real')
    net = GAN(input_real, z_size, learning_rate)

    # Create the dataset wrapper around the train and test sets.
    dataset = Dataset(trainset, testset)

    # Train the model.
    batch_size = 128
    epochs = 25
    # train() is unpacked into three values; none of them is used here.
    _, _, _ = train(net, dataset, epochs, batch_size, z_size)
svnh_semi_supervised_model_train.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录