def xavier_normal_dist_conv3d(shape, dtype=None):
    """Sample a Xavier/Glorot-style truncated-normal initial value for a 3-D conv kernel.

    The standard deviation used is::

        sqrt(3 / (prod(shape[:3]) * sum(shape[3:])))

    Args:
        shape: Kernel shape. Presumably
            ``[depth, height, width, in_channels, out_channels]``. The first
            three entries are multiplied and the remaining ones are summed,
            which under that layout equals ``fan_in + fan_out``. TODO: confirm
            against callers.
        dtype: Floating TensorFlow dtype of the sampled tensor. Defaults to
            ``tf.float32``, the default of ``tf.truncated_normal``.

    Returns:
        A tensor of shape ``shape`` drawn from ``tf.truncated_normal`` with
        mean 0 and the standard deviation above. ``tf.truncated_normal``
        re-draws values farther than two standard deviations from the mean,
        so the empirical spread is somewhat below ``stddev``.
    """
    if dtype is None:
        dtype = tf.float32
    # Reducing an integer shape yields an integer tensor. Dividing a Python
    # float by it fails, because 3. cannot be converted to an int dtype. A
    # float64 result would also not match a float32 sampling dtype. Cast the
    # denominator to the sampling dtype explicitly.
    denominator = tf.cast(
        tf.reduce_prod(shape[:3]) * tf.reduce_sum(shape[3:]), dtype)
    # NOTE(review): the usual Glorot-normal variance is 2 / (fan_in + fan_out).
    # This function uses 3 / (...). Kept unchanged to preserve existing
    # behavior; confirm the intent.
    stddev = tf.sqrt(3. / denominator)
    return tf.truncated_normal(shape, mean=0, stddev=stddev, dtype=dtype)