def compute_energy(hidden, state, attn_size, attn_keep_prob=None, pervasive_dropout=False, layer_norm=False,
                   mult_attn=False, **kwargs):
    """Score every position of `hidden` against `state` (unnormalised attention energies).

    Two scoring functions are available:

    * multiplicative (`mult_attn=True`): both inputs are linearly projected to
      `attn_size` and contracted over the last axis with
      ``einsum('ijk,ik->ij')``;
    * additive (default): ``sum(v_a * tanh(U_a(hidden) + W_a(state)), axis=2)``.

    Args:
        hidden: rank-3 tensor (its third axis is read for the dropout noise
            shape and it is the 'ijk' operand of the einsum). Presumably
            (batch, time, features) -- confirm against the caller.
        state: rank-2 tensor sharing its first axis with `hidden`. Presumably
            (batch, features) -- confirm against the caller.
        attn_size: output size passed to the `dense` projections and the
            length of the `v_a` vector.
        attn_keep_prob: keep probability for dropout applied to both inputs;
            `None` disables dropout entirely.
        pervasive_dropout: when true, the dropout noise shape is 1 on every
            axis but the last, so a single mask over the last axis is shared
            across the other axes.
        layer_norm: additive path only. Applies layer normalisation to the
            projected state and to `hidden` (the latter without centering),
            and drops the bias of the `W_a` projection.
        mult_attn: selects the multiplicative scoring function.
        **kwargs: accepted and ignored; nothing in the body reads them.

    Returns:
        A rank-2 tensor indexed by the first two axes of `hidden` (assuming
        `dense` only changes the last axis -- TODO confirm).
    """
    if attn_keep_prob is not None:
        # State dropout is built before hidden dropout; keep this order so the
        # graph ops are created in the same sequence.
        if pervasive_dropout:
            state_mask_shape = [1, tf.shape(state)[1]]
        else:
            state_mask_shape = None
        state = tf.nn.dropout(state, keep_prob=attn_keep_prob, noise_shape=state_mask_shape)

        if pervasive_dropout:
            hidden_mask_shape = [1, 1, tf.shape(hidden)[2]]
        else:
            hidden_mask_shape = None
        hidden = tf.nn.dropout(hidden, keep_prob=attn_keep_prob, noise_shape=hidden_mask_shape)

    if mult_attn:
        # Multiplicative scoring: dot product of the two bias-free projections.
        query = dense(state, attn_size, use_bias=False, name='state')
        keys = dense(hidden, attn_size, use_bias=False, name='hidden')
        return tf.einsum('ijk,ik->ij', keys, query)

    # Additive scoring. The state projection only carries a bias when layer
    # normalisation is off (presumably because layer_norm adds its own offset).
    query = dense(state, attn_size, use_bias=not layer_norm, name='W_a')
    # Insert a length-1 axis so the state term broadcasts over axis 1 of `hidden`.
    query = tf.expand_dims(query, axis=1)

    if layer_norm:
        query = tf.contrib.layers.layer_norm(query, scope='layer_norm_state')
        hidden = tf.contrib.layers.layer_norm(hidden, center=False, scope='layer_norm_hidden')

    keys = dense(hidden, attn_size, use_bias=False, name='U_a')
    v = get_variable('v_a', [attn_size])

    # Weight the tanh activations by v_a and sum over the attn_size axis.
    return tf.reduce_sum(v * tf.tanh(keys + query), axis=2)
# (Web page navigation residue, not part of the code: "Comment list" / "Article table of contents".)