def inner_fn_sample(stm1):
    """Run one ancestral-sampling step of the state-space model.

    Computes a Gaussian prior over the new latent state from the previous
    state ``stm1`` and samples ``st`` from it. It then decodes ``st`` through
    three ReLU layers and samples once from each of the three Gaussian
    observation heads (``o``, ``oh``, ``oa``).

    Presumably this is the step function of a ``theano.scan`` loop, with the
    first return value ``st`` fed back in as ``stm1``. Verify against the
    caller.

    Module-level names read: ``T``, ``theano_rng``, the ``Wl_*`` weights and
    ``bl_*`` biases, ``sig_min_states``, ``sig_min_obs`` and the sizes
    ``n_s``, ``n_o``, ``n_oh``, ``n_oa``, ``n_samples``.

    Args:
        stm1: symbolic previous state. It is left-multiplied by the
            ``Wl_*_stm1`` matrices, so samples presumably run along columns,
            i.e. shape (n_s, n_samples). TODO confirm.

    Returns:
        Tuple ``(st, ohtmu, ohtsig, ot, oht, oat, prior_stmu, prior_stsig)``:
        - the sampled state;
        - mean and std of the ``oh`` head;
        - the sampled ``o``, ``oh`` and ``oa`` observations;
        - mean and std of the state prior.
        The means and stds of the ``o`` and ``oa`` heads are computed but
        not returned.
    """
    # State prior: tanh-squashed mean. The std is a softplus (always > 0)
    # plus sig_min_states, so it stays strictly above that floor.
    prior_stmu = T.tanh(T.dot(Wl_stmu_stm1, stm1) + bl_stmu)
    prior_stsig = T.nnet.softplus(T.dot(Wl_stsig_stm1, stm1) + bl_stsig) + sig_min_states

    # Optional explicit prior on the score (row 0 of the state), currently
    # disabled. It pins mean 0.1 / std 0.001 during the last five time steps
    # (t >= n_run_steps - 5); before that the learned prior is kept.
    # NOTE(review): `t` is not a parameter of this function. Enabling this
    # requires passing the step index in (e.g. via scan `sequences`).
    # Confirm before uncommenting.
    # prior_stmu = ifelse(T.lt(t, n_run_steps - 5), prior_stmu, T.set_subtensor(prior_stmu[0, :], 0.1))
    # prior_stsig = ifelse(T.lt(t, n_run_steps - 5), prior_stsig, T.set_subtensor(prior_stsig[0, :], 0.001))

    # Sample the new state as mean + noise * std. The noise has shape
    # (n_s, n_samples) and is presumably standard normal
    # (theano_rng.normal with default avg/std). Confirm.
    st = prior_stmu + theano_rng.normal((n_s, n_samples)) * prior_stsig

    # Three-layer ReLU decoder shared by all observation heads.
    ost = T.nnet.relu(T.dot(Wl_ost_st, st) + bl_ost)
    ost2 = T.nnet.relu(T.dot(Wl_ost2_ost, ost) + bl_ost2)
    ost3 = T.nnet.relu(T.dot(Wl_ost3_ost2, ost2) + bl_ost3)

    # Gaussian observation heads: linear means; stds are softplus plus
    # sig_min_obs.
    # NOTE(review): the head weights are named *_st but are applied to the
    # decoder output ost3, not to st directly.
    otmu = T.dot(Wl_otmu_st, ost3) + bl_otmu
    otsig = T.nnet.softplus(T.dot(Wl_otsig_st, ost3) + bl_otsig) + sig_min_obs
    ohtmu = T.dot(Wl_ohtmu_st, ost3) + bl_ohtmu
    ohtsig = T.nnet.softplus(T.dot(Wl_ohtsig_st, ost3) + bl_ohtsig) + sig_min_obs
    oatmu = T.dot(Wl_oatmu_st, ost3) + bl_oatmu
    oatsig = T.nnet.softplus(T.dot(Wl_oatsig_st, ost3) + bl_oatsig) + sig_min_obs

    # Draw one sample per head. The theano_rng.normal calls stay in their
    # original order (st, then o, oh, oa), since the call order may affect
    # which random stream each draw receives.
    ot = otmu + theano_rng.normal((n_o, n_samples)) * otsig
    oht = ohtmu + theano_rng.normal((n_oh, n_samples)) * ohtsig
    oat = oatmu + theano_rng.normal((n_oa, n_samples)) * oatsig

    return st, ohtmu, ohtsig, ot, oht, oat, prior_stmu, prior_stsig
# Define initial state and action
评论列表
文章目录