def calc_tvd(sess, Generator, Data, N=50000, nbins=10):
    """Compare generated and real samples through binned histograms.

    Draws ``N`` samples from ``Data.X`` and ``N`` from ``Generator.X`` by
    running them through ``sess``, bins both sample sets on a regular grid
    over the unit hypercube ``[0, 1]^D`` and compares the resulting per-bin
    probability masses.

    Args:
        sess: Object exposing ``run(fetches, feed_dict)``.
        Generator: Object exposing ``step``, ``X`` and ``N`` fetch/feed handles.
        Data: Object exposing ``X`` and ``N`` fetch/feed handles.
        N: Number of samples requested from each source.
        nbins: Number of equal-width bins along every axis (an integer).

    Returns:
        A tuple ``(step, tvd, mvd)`` where ``step`` is the fetched value of
        ``Generator.step``, ``tvd`` is half the summed absolute per-bin
        difference (total variation distance between the two histograms)
        and ``mvd`` is the largest absolute per-bin difference.

    Note:
        Assumes both fetched samples are ``(N, D)`` arrays with the same
        ``D`` -- TODO confirm against the graph definition. Points outside
        ``[0, 1]`` on any axis are ignored by ``np.histogramdd``.
    """
    Xd = sess.run(Data.X, {Data.N: N})
    step, Xg = sess.run([Generator.step, Generator.X], {Generator.N: N})

    # Take the dimensionality from the data instead of hard-coding 3.
    n_dims = np.shape(Xd)[1]
    unit_cube = [[0, 1]] * n_dims

    # ``density=True`` replaces the ``normed`` keyword, which newer NumPy
    # releases no longer accept.
    p_gen, _ = np.histogramdd(Xg, bins=nbins, range=unit_cube, density=True)
    p_dat, _ = np.histogramdd(Xd, bins=nbins, range=unit_cube, density=True)

    # A density integrates to 1; each bin has volume (1/nbins)**n_dims, so
    # dividing by nbins**n_dims turns densities into probability masses.
    n_cells = nbins ** n_dims
    p_gen /= n_cells
    p_dat /= n_cells

    abs_diff = np.abs(p_gen - p_dat)
    tvd = 0.5 * np.sum(abs_diff)
    mvd = np.max(abs_diff)
    return step, tvd, mvd
评论列表
文章目录