def get_normalized_gamma(size, filter_height, filter_width):
    """Get normalized gamma.

    Computes ``log(filter_height * filter_width) - log(prod(size, axis=-1))``:
    the log of the filter area minus the log of the product taken over the
    last axis of ``size``.

    Args:
        size: Tensor (or tensor-like) of shape [B, T, 2], [B, 2] or [2].
            Integer inputs are promoted to float32; floating inputs keep
            their dtype.
        filter_height: int
        filter_width: int

    Returns:
        lg_gamma: Tensor of shape [B, T], [B] or [] (scalar), in the floating
            dtype of ``size`` (float32 when ``size`` is integer).
    """
    size = tf.convert_to_tensor(size)
    # tf.log only accepts floating types, so promote integer sizes rather
    # than failing. Floating inputs (float32/float64) are left untouched.
    if not size.dtype.is_floating:
        size = tf.cast(size, tf.float32)
    filter_area = filter_height * filter_width
    # axis=-1 reduces over the last dimension for any input rank. This is
    # equivalent to the dynamic tf.shape(tf.shape(size)) - 1 axis.
    area = tf.reduce_prod(size, axis=-1)
    # Build the filter-area term in the same dtype as `area`. A bare Python
    # float would become a float32 constant, and subtracting a float64
    # `area` from it would raise a dtype mismatch.
    lg_gamma = tf.log(tf.cast(filter_area, area.dtype)) - tf.log(area)
    return lg_gamma
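# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): navigation
# residue from the web page this snippet was copied from; not code.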