svnh_semi_supervised_model_train.py source code

python

Project: tf_serving_example  Author: Vetal1977
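import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder, tf.reset_default_graph)

import utils  # project module with the data download/loading helpers

# GAN, Dataset, train and create_checkpoints_dir are assumed to be defined
# elsewhere in this project; they are not shown in this snippet.


def main():
    # preparations: checkpoint directory plus the train/test data
    create_checkpoints_dir()
    utils.download_train_and_test_data()
    trainset, testset = utils.load_data_sets()

    # create the real input for the GAN model (its discriminator) and
    # the GAN model itself
    real_size = (32, 32, 3)  # image height, width, RGB channels
    z_size = 100  # dimension of the generator's noise vector
    learning_rate = 0.0003

    tf.reset_default_graph()
    # batch dimension is None so any batch size can be fed
    input_real = tf.placeholder(tf.float32, (None, *real_size), name='input_real')
    net = GAN(input_real, z_size, learning_rate)

    # create dataset
    dataset = Dataset(trainset, testset)

    # train the model; train() returns three values that are discarded here
    batch_size = 128
    epochs = 25
    _, _, _ = train(net, dataset, epochs, batch_size, z_size)


if __name__ == '__main__':
    main()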
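The body of `train()` is not shown in this snippet. As a rough illustration only, the pure-Python sketch below mimics the epoch/mini-batch loop such a function typically runs with `epochs = 25` and `batch_size = 128`. The helpers `iterate_batches`, `sample_noise`, `steps_per_run` and `train_sketch` are hypothetical and not part of tf_serving_example. Shuffling every epoch, keeping the final partial batch, and drawing the noise uniformly from [-1, 1] are assumptions borrowed from common DCGAN-style training code, not confirmed from this project.

```python
import math
import random


def iterate_batches(images, batch_size, shuffle=True, seed=None):
    """Yield successive mini-batches from `images` (hypothetical helper).

    The final batch is smaller than `batch_size` when len(images)
    is not a multiple of it (assumption: partial batches are kept).
    """
    indices = list(range(len(images)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [images[i] for i in indices[start:start + batch_size]]


def sample_noise(batch_size, z_size, rng):
    """Draw one latent vector per example, each entry uniform in [-1, 1].

    The uniform prior is an assumption, not taken from the project.
    """
    return [[rng.uniform(-1.0, 1.0) for _ in range(z_size)]
            for _ in range(batch_size)]


def steps_per_run(num_examples, batch_size, epochs):
    """Total optimizer steps when every epoch also runs the partial last batch."""
    return epochs * math.ceil(num_examples / batch_size)


def train_sketch(images, epochs, batch_size, z_size, seed=0):
    """Hypothetical stand-in for train(): only counts the steps it would run."""
    rng = random.Random(seed)
    steps = 0
    for epoch in range(epochs):
        # reshuffle with a different seed each epoch
        for batch in iterate_batches(images, batch_size, seed=seed + epoch):
            z = sample_noise(len(batch), z_size, rng)
            # A real implementation would run the discriminator and generator
            # optimizers here, feeding `batch` to input_real and `z` to the
            # generator's noise input.
            assert len(z) == len(batch)
            steps += 1
    return steps
```

For example, if `trainset` held the full SVHN training split (73,257 images), `steps_per_run(73257, 128, 25)` would give 25 * 573 = 14,325 steps; a `Dataset` that drops the last partial batch would run 25 * 572 = 14,300 instead.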