def sharp_weights(self, after_conv_shift, sharp_gamma, eps=1e-30):
    """Sharpen the final addressing weights.

    Each weight is raised to the power ``sharp_gamma`` and the result is
    renormalised over the memory-location axis (axis 1), i.e.
    ``w_i ** gamma / sum_j(w_j ** gamma)``.

    Parameters
    ----------
    after_conv_shift : Tensor (batch_size, memory_locations, number_of_keys)
        Weights after circular convolution.
    sharp_gamma : Tensor (batch_size, number_of_keys)
        Sharpening exponent, one value per batch entry and key; it is
        broadcast across all memory locations.
    eps : float, optional
        Lower bound applied to the normalising sum. It only takes effect when
        that sum is (near) zero -- e.g. when every powered weight in a column
        underflows -- in which case the column becomes 0 instead of NaN.
        Defaults to 1e-30.

    Returns
    -------
    Tensor (batch_size, memory_locations, number_of_keys)
        Sharpened weights; each (batch, key) column sums to 1 (up to
        rounding) whenever its normalising sum exceeds ``eps``.
    """
    # (batch, keys) -> (batch, 1, keys) so the exponent broadcasts over the
    # memory-location axis.
    gamma = tf.expand_dims(sharp_gamma, 1)
    powed = tf.pow(after_conv_shift, gamma)
    # Per-(batch, key) total over memory locations, kept as (batch, 1, keys)
    # so it broadcasts in the division below.
    total = tf.expand_dims(tf.reduce_sum(powed, 1), 1)
    # Guard the denominator: an all-zero column would otherwise yield
    # 0 / 0 = NaN, which propagates into anything computed from these weights.
    total = tf.maximum(total, eps)
    return powed / total
评论列表
文章目录