GAN.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:2017_iv_deep_radar 作者: tawheeler 项目源码 文件源码
def train_generator(nsteps):
        mean_loss = 0.0
        for i in range(1,nsteps):
            batch_indeces = np.random.randint(0,O_train.shape[0],args.batch_size)
            o_in = O_train[batch_indeces,:,:,:]
            t_in = T_train[batch_indeces,:,:,:]
            y_in = Y_train[batch_indeces,:,:,:]
            r = generator.fit([o_in,t_in,y_in], [y_in, d_comb],
                             #callbacks=[TensorBoard(log_dir=args.tblog + '_G', write_graph=False)],
                             verbose=0)
            loss = r.history['loss'][0]
            mean_loss = mean_loss + loss
        return mean_loss / nsteps
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号