def likelihood_classification(w, n_classes, n_samples):
    """Reshape ``w`` into a (n_classes, n_samples) grid and score it with ``predictive_ll``.

    Args:
        w: Tensor-like value whose total number of elements must equal
            ``n_classes * n_samples``. ``tf.reshape`` raises otherwise.
            It need not be a scalar; any layout with that element count works.
        n_classes: Number of rows in the reshaped tensor. Presumably one
            row per class (TODO confirm against callers).
        n_samples: Number of columns in the reshaped tensor.

    Returns:
        Whatever ``predictive_ll`` returns for the reshaped
        ``[n_classes, n_samples]`` tensor.

    Note:
        Earlier code hinted that hard class labels were obtained via an
        argmax over axis 0 of this result. If that is needed, apply
        ``tf.argmax`` at the call site. Confirm the axis against
        ``predictive_ll``'s output layout first; ``tf.arg_max`` is a
        deprecated alias.
    """
    # Lay w out as a 2-D grid so predictive_ll sees an explicit
    # (n_classes, n_samples) structure.
    w = tf.reshape(w, [n_classes, n_samples])
    ll = predictive_ll(w)
    return ll