def _partition_and_stitch(self, args, func_name):
    """
    Apply ``<func_name>`` of each likelihood to the rows assigned to it.

    The last column of the 'Y' argument holds, for every row, the index of
    the likelihood in ``self.likelihood_list`` that the row belongs to. That
    column is stripped from Y, every argument is split row-wise with
    ``tf.dynamic_partition``, each likelihood's ``<func_name>`` is called on
    its own chunk, and the per-likelihood results are recombined with
    ``tf.dynamic_stitch`` so they line up with the original row order.

    :param args: sequence of tensors to be passed to ``<func_name>`` of each
        likelihood. ``args[-1]`` is the 'Y' argument, whose last column
        contains the likelihood indexes. The sequence itself is not modified.
        (Assumes Y is 2-D and all tensors share the same leading dimension
        -- TODO confirm against callers.)
    :param func_name: name of the method to look up on every likelihood.
    :return: tensor with the stitched results of the per-likelihood calls.
    """
    # Work on a shallow copy: replacing the last entry must not leak back
    # into the list owned by the caller.
    tensors = list(args)

    # Separate the index column (Y[:, -1]) from the observations (Y[:, :-1]).
    Y = tensors[-1]
    ind = tf.cast(Y[:, -1], tf.int32)
    tensors[-1] = Y[:, :-1]

    # Split every argument into one chunk per likelihood, then regroup so
    # that each item holds all the arguments for a single likelihood.
    chunks = zip(*[tf.dynamic_partition(X, ind, self.num_likelihoods) for X in tensors])

    # Look up the requested method on each likelihood.
    # NOTE(review): convert=False presumably makes `self.likelihood_list`
    # yield the likelihood objects instead of tensors -- confirm. The
    # with-block is assumed to cover only this lookup.
    with params_as_tensors_for(self, convert=False):
        funcs = [getattr(lik, func_name) for lik in self.likelihood_list]
    results = [f(*chunk) for f, chunk in zip(funcs, chunks)]

    # Partition the row positions the same way as the data, so that
    # dynamic_stitch can place every result back at its original row.
    partitions = tf.dynamic_partition(tf.range(0, tf.size(ind)), ind, self.num_likelihoods)
    return tf.dynamic_stitch(partitions, results)
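# NOTE: the two lines that stood here ("comment list", "article contents")
# were web-page navigation text picked up when this snippet was copied; they are not code.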