def get_GP_samples(Y,T,X,ind_kf,ind_kt,num_obs_times,num_obs_values,
                   num_rnn_grid_times,med_cov_grid):
    """
    Draw GP samples at the grid points of every row and stack the results.

    For each row ``i`` along the first axis of ``T`` this:
      * takes the first ``num_obs_values[i]`` entries of ``Y[i]``,
        ``ind_kf[i]`` and ``ind_kt[i]``, the first ``num_obs_times[i]``
        entries of ``T[i]`` and the first ``num_rnn_grid_times[i]`` entries
        of ``X[i]`` (each flattened to 1-D),
      * passes them to ``draw_GP``,
      * zero-pads the draws along axis 1 up to ``tf.shape(X)[1]``,
      * tiles ``med_cov_grid[i]`` ``n_mc_smps`` times along axis 0 and
        concatenates it onto the last axis of the padded draws.

    The per-row results are concatenated along axis 0 and returned.

    NOTE(review): the padding implies ``draw_GP`` returns a tensor shaped
    (n_mc_smps, num_rnn_grid_times[i], M), and the accumulator's last
    dimension is ``input_dim`` -- confirm against ``draw_GP`` and the
    module-level definitions of ``n_mc_smps``, ``M`` and ``input_dim``.
    """
    longest_grid = tf.shape(X)[1]
    # Empty accumulator: zero rows now, one chunk appended per iteration.
    stacked_init = tf.zeros([0, longest_grid, input_dim])
    n_rows = tf.shape(T)[0]

    def row_prefix(batch, lengths, row):
        # First lengths[row] entries of batch[row], flattened to 1-D.
        return tf.reshape(tf.slice(batch, [row, 0], [1, lengths[row]]), [-1])

    # tf.while_loop is needed because the row count is a dynamic tensor.
    def keep_going(row, stacked):
        return row < n_rows

    def append_row(row, stacked):
        y_row = row_prefix(Y, num_obs_values, row)
        t_row = row_prefix(T, num_obs_times, row)
        kf_row = row_prefix(ind_kf, num_obs_values, row)
        kt_row = row_prefix(ind_kt, num_obs_values, row)
        x_row = row_prefix(X, num_rnn_grid_times, row)
        grid_len = num_rnn_grid_times[row]

        draws = draw_GP(y_row, t_row, x_row, kf_row, kt_row)

        # Pad the grid axis with zeros so every row reaches longest_grid.
        n_missing = longest_grid - grid_len
        filler = tf.zeros((n_mc_smps, n_missing, M))
        draws_padded = tf.concat([draws, filler], 1)

        # Repeat this row's covariates once per sample, then append them
        # on the last axis.
        row_covs = tf.slice(med_cov_grid, [row, 0, 0], [1, -1, -1])
        covs_per_sample = tf.tile(row_covs, [n_mc_smps, 1, 1])
        row_block = tf.concat([draws_padded, covs_per_sample], 2)

        stacked = tf.concat([stacked, row_block], 0)
        return row + 1, stacked

    counter = tf.constant(0)
    # The accumulator grows along axis 0 on every pass, so its shape
    # invariant is left fully unconstrained.
    loop_shapes = [counter.get_shape(), tf.TensorShape([None, None, None])]
    _, stacked_out = tf.while_loop(keep_going, append_row,
                                   loop_vars=[counter, stacked_init],
                                   shape_invariants=loop_shapes)
    return stacked_out
评论列表
文章目录