cgan.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:shenlan 作者: vector-1127 项目源码 文件源码
def train(BATCH_SIZE, epochs=100):
    """Train the conditional GAN's generator and discriminator on the 'train' split.

    Loads paired data via ``get_data('train')`` and rescales both halves with
    ``(x - 127.5) / 127.5``. That maps [0, 255] to [-1, 1], presumably because
    the data is 8-bit pixels (confirm against ``get_data``).

    Each mini-batch gets one discriminator update followed by one update of the
    stacked generator+discriminator model. Every 20 batches, two things happen:
    - a preview grid of generated images is written to ``<epoch>_<index>.png``;
    - the weights are saved to the files ``'generator'`` and
      ``'discriminator'``, passing ``True`` as the second ``save_weights``
      argument.

    Args:
        BATCH_SIZE: number of input/target pairs per mini-batch. Trailing
            samples that do not fill a whole batch are skipped every epoch.
        epochs: number of passes over the training data. The default of 100
            is the value that was previously hard-coded.
    """
    (X_train, Y_train) = get_data('train')
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5

    discriminator = discriminator_model()
    generator = generator_model()
    generator.summary()
    discriminator_on_generator = generator_containing_discriminator(generator, discriminator)
    generator.compile(loss='mse', optimizer="rmsprop")
    # Two outputs, paired positionally with the target list passed to
    # train_on_batch below: output 0 uses generator_l1_loss against the real
    # target image; output 1 uses discriminator_on_generator_loss.
    discriminator_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")

    # Per-sample shape of the label arrays fed to the discriminator outputs.
    # NOTE(review): hard-coded from the original code; confirm it matches the
    # output shape of discriminator_model().
    disc_out_shape = (1, 64, 64)
    num_batches = X_train.shape[0] // BATCH_SIZE

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            batch = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
            input_batch = X_train[batch]
            image_batch = Y_train[batch]

            generated_images = generator.predict(input_batch)
            if index % 20 == 0:
                _save_preview(generated_images, epoch, index)

            # Real pairs (input, target) first, then fake pairs
            # (input, generated), each joined along axis 1.
            real_pairs = np.concatenate((input_batch, image_batch), axis=1)
            fake_pairs = np.concatenate((input_batch, generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            # NOTE(review): every target here is zero, although the original
            # comment hinted at real=1 / fake=0 labels. discriminator_loss
            # presumably derives its own real/fake targets and uses y only
            # for its shape, so confirm. The leading dimension must equal
            # len(X) == 2 * BATCH_SIZE; it was previously hard-coded to 20.
            y = np.zeros((2 * BATCH_SIZE,) + disc_out_shape)
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))

            # NOTE(review): in Keras, changing `trainable` only takes effect
            # at the next compile(), so these toggles may be no-ops; confirm.
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(
                input_batch,
                [image_batch, np.ones((BATCH_SIZE,) + disc_out_shape)])
            discriminator.trainable = True
            # For multi-output Keras models, g_loss[0] is the total loss and
            # g_loss[1] is the first output's loss (generator_l1_loss).
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)


def _save_preview(generated_images, epoch, index):
    """Write a tiled grid of generated images to ``<epoch>_<index>.png``."""
    image = combine_images(generated_images)
    # Undo the (x - 127.5) / 127.5 input scaling.
    image = image * 127.5 + 127.5
    # NOTE(review): if combine_images returns (C, H, W), swapping axes 0 and
    # 2 yields (W, H, C), i.e. a spatially transposed image.
    # transpose(1, 2, 0) would keep the orientation; confirm before changing.
    image = np.swapaxes(image, 0, 2)
    # Convert explicitly (round, saturate to [0, 255], cast to uint8) instead
    # of relying on OpenCV's implicit float conversion for the PNG encoder.
    image = np.clip(np.rint(image), 0, 255).astype(np.uint8)
    cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号