GAN.py source code

Language: Python

Project: 2017_iv_deep_radar    Author: tawheeler

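import numpy as np

# O_train, T_train, Y_train, args, generator, adversary and d_disc are
# module-level objects defined elsewhere in GAN.py.
def train_discriminator(nsteps):
    """Train the discriminator for nsteps minibatches and return the mean loss."""
    mean_loss = 0.0
    for _ in range(nsteps):  # exactly nsteps steps, matching the division below
        # pick real samples
        batch_indices = np.random.randint(0, O_train.shape[0], args.batch_size)
        y_real = Y_train[batch_indices, :, :, :]

        # pick fake samples: run the generator on an independently drawn batch
        batch_indices = np.random.randint(0, O_train.shape[0], args.batch_size)
        o_in = O_train[batch_indices, :, :, :]
        t_in = T_train[batch_indices, :, :, :]
        y_in = Y_train[batch_indices, :, :, :]
        # [0] selects the first output of the multi-output generator
        y_fake = generator.predict([o_in, t_in, y_in])[0]

        # train on real samples stacked above fake ones; d_disc must hold
        # one target per row of y_disc, in the same real-then-fake order
        y_disc = np.vstack([y_real, y_fake])
        r = adversary.fit(y_disc, d_disc,
                          # callbacks=[TensorBoard(log_dir=args.tblog + '_D', write_graph=False)],
                          verbose=0)
        mean_loss += r.history['loss'][0]
    return mean_loss / nsteps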
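The function relies on a label array `d_disc` that is not shown in this file excerpt. As a rough illustration only, the sketch below shows one common way to build such targets for a batch stacked as real-then-fake, and how averaging over exactly `nsteps` calls works. The helper names `make_discriminator_batch` and `mean_step_loss` are hypothetical, and the 1 = real / 0 = fake convention is an assumption; the original project may use a different label encoding.

```python
import numpy as np


def make_discriminator_batch(y_real, y_fake):
    """Stack real samples above fake ones and build matching 0/1 targets.

    Hypothetical helper: assumes label 1 = real and 0 = fake. The original
    d_disc may use another convention (e.g. one-hot or smoothed labels).
    """
    x = np.vstack([y_real, y_fake])
    labels = np.concatenate([np.ones(len(y_real)), np.zeros(len(y_fake))])
    return x, labels


def mean_step_loss(train_step, nsteps):
    """Call train_step() exactly nsteps times and average the returned losses."""
    total = 0.0
    for _ in range(nsteps):
        total += train_step()
    return total / nsteps


# Usage with toy data: four 2x2 single-channel samples per class.
rng = np.random.default_rng(0)
y_real = rng.normal(size=(4, 2, 2, 1))
y_fake = rng.normal(size=(4, 2, 2, 1))
x, labels = make_discriminator_batch(y_real, y_fake)
print(x.shape, labels)  # (8, 2, 2, 1) [1. 1. 1. 1. 0. 0. 0. 0.]

losses = iter([0.75, 0.5, 0.25])
print(mean_step_loss(lambda: next(losses), 3))  # 0.5
```

Iterating with `range(nsteps)` rather than `range(1, nsteps)` matters here: the latter makes only `nsteps - 1` calls while still dividing by `nsteps`, which biases the reported mean loss downward.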