def get_mixture_coef( self, args, output, corr_scale = 0.95 ):
    """Split the raw network output into MDN distribution parameters.

    Applies the output transformations of eqs. 18-22 of Graves,
    "Generating Sequences With Recurrent Neural Networks"
    (http://arxiv.org/abs/1308.0850).

    Column layout of ``output`` along axis 1, as read by this method:
        0                      -> eos flag logit
        1                      -> eod flag logit
        2 : last               -> mixture parameters, split into 6 equal
                                  groups (pi, mu1, mu2, sigma1, sigma2, corr)
        last :                 -> not read here (presumably the
                                  ``nrClassOutputVars`` class outputs)
    where ``last = args.nroutputvars_raw - args.nrClassOutputVars``.

    Args:
        args: config object; only ``nroutputvars_raw`` and
            ``nrClassOutputVars`` are read. ``last - 2`` must be divisible
            by 6 or ``tf.split`` fails.
        output: raw output tensor indexed as ``output[:, column]``.
            NOTE(review): assumed 2-D (rows, features) so that axis 1 is
            the last axis used by ``tf.nn.softmax`` -- confirm with caller.
        corr_scale: factor applied to ``tanh(corr)`` so the correlation
            stays strictly inside (-1, 1). Defaults to 0.95 (the
            previously hard-coded value).

    Returns:
        list ``[pi, mu1, mu2, sigma1, sigma2, corr, eos, eod]`` of tensors.
    """
    z = output
    # End of the mixture-parameter columns; anything after is not used here.
    last = args.nroutputvars_raw - args.nrClassOutputVars

    # Bernoulli flags squashed into (0, 1) with a sigmoid (eq 18 for eos).
    z_eos = tf.sigmoid( z[ :, 0 ] )
    z_eod = tf.sigmoid( z[ :, 1 ] )

    z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split( z[ :, 2:last ], 6, 1 )

    # eq 19: mixture weights. tf.nn.softmax subtracts the max internally for
    # numerical stability and normalises over the last axis. This replaces
    # the manual version that needed the removed `keep_dims` kwarg and the
    # TF1-only `tf.reciprocal` alias.
    z_pi = tf.nn.softmax( z_pi )

    # eq 20: the means mu1, mu2 need no transformation.
    # eq 21: standard deviations must be positive.
    z_sigma1 = tf.exp( z_sigma1 )
    z_sigma2 = tf.exp( z_sigma2 )

    # eq 22: tanh maps into (-1, 1). Scaling keeps |rho| away from exactly 1,
    # presumably because the downstream bivariate density divides by
    # 1 - rho^2.
    z_corr = corr_scale * tf.tanh( z_corr )

    return [ z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_eos, z_eod ]
model.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录