def predict_snps(y, cut_off_prob=0.5, already_split=False):
    """Predict which SNPs cause epistasis by thresholding their probabilities.

    Arguments:
        y: the model's output tensor. Unless ``already_split`` is True it is
            first passed through ``get_causing_epi_probs`` to obtain the
            "causing epistasis" probabilities.
        cut_off_prob: float cutoff; every entry whose probability is
            ``>= cut_off_prob`` counts as a hit for its SNP.
            Recommended values:
                0.5 for 2-classifier model
                0.98 for 1-classifier model
        already_split: if True, ``y`` is taken to already hold the
            causing-epistasis probabilities and is used as-is.

    Returns:
        A tuple ``(top_pred_snps, count)``:
            top_pred_snps: 1-D tensor of the unique SNP indices (axis 1 of
                the probability tensor) that have at least one entry at or
                above the cutoff, in order of first occurrence.
            count: 1-D tensor of the same length giving how many entries met
                the cutoff for each of those SNPs.
    """
    with tf.name_scope('snp_prediction'):
        if already_split:
            y_left = y
        else:
            y_left = get_causing_epi_probs(y)

        # Single-argument tf.where returns an int64 tensor of shape
        # [num_hits, rank(y_left)] holding the coordinates of every True entry.
        hit_coords = tf.where(tf.greater_equal(y_left, cut_off_prob))

        # Column 1 of each coordinate is used as the SNP index
        # (NOTE(review): assumes axis 1 of y_left indexes SNPs - confirm).
        # Plain slicing replaces the old `tf.split(1, 3, ...)` call, whose
        # positional signature changed in TF 1.0 and which also required
        # y_left to be exactly rank 3.
        top_snp_indices = hit_coords[:, 1]

        top_pred_snps, _, count = tf.unique_with_counts(top_snp_indices)
        return top_pred_snps, count
评论列表
文章目录