import tensorflow as tf
from tensorflow.contrib import layers  # TF 1.x: fully_connected lives in contrib


def _attention(self, inputs, output_size, gene, variation, activation_fn=tf.tanh):
    """Context-aware attention pooling of `inputs`, where the attention
    context is derived from the gene and variation embeddings."""
    inputs_shape = inputs.get_shape()
    if len(inputs_shape) != 3 and len(inputs_shape) != 4:
        raise ValueError('Shape of input must have 3 or 4 dimensions')
    # Project the inputs into the attention space.
    input_projection = layers.fully_connected(inputs, output_size,
                                              activation_fn=activation_fn)
    # Build a per-document context vector from the gene and variation embeddings.
    doc_context = tf.concat([gene, variation], axis=1)
    doc_context_vector = layers.fully_connected(doc_context, output_size,
                                                activation_fn=activation_fn)
    # Insert singleton time axes so the context broadcasts against the inputs:
    # once for 3-D [batch, steps, dim] inputs, twice for 4-D
    # [batch, sentences, words, dim] inputs.
    doc_context_vector = tf.expand_dims(doc_context_vector, 1)
    if len(inputs_shape) == 4:
        doc_context_vector = tf.expand_dims(doc_context_vector, 1)
    # Unnormalized attention scores: dot product of each projected input
    # with the document context vector.
    vector_attn = input_projection * doc_context_vector
    vector_attn = tf.reduce_sum(vector_attn, axis=-1, keepdims=True)
    # Normalize the scores along axis 1, then pool the projected inputs
    # over the second-to-last (time) axis.
    attention_weights = tf.nn.softmax(vector_attn, axis=1)
    weighted_projection = input_projection * attention_weights
    outputs = tf.reduce_sum(weighted_projection, axis=-2)
    return outputs
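For context, here is a minimal usage sketch under TensorFlow 1.x. The shapes (batch 8, 50 words, 128-dim word encodings, 32-dim gene/variation embeddings) and the placeholder wiring are assumptions for illustration, not part of the original model; since self is unused in the method body, None is passed in its place.

import numpy as np
import tensorflow as tf

# Hypothetical shapes: word-level inputs are [batch, words, dim],
# gene/variation embeddings are [batch, embedding_size].
inputs = tf.placeholder(tf.float32, [8, 50, 128])
gene = tf.placeholder(tf.float32, [8, 32])
variation = tf.placeholder(tf.float32, [8, 32])

# self is unused inside _attention, so None suffices for a standalone test.
outputs = _attention(None, inputs, output_size=100,
                     gene=gene, variation=variation)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(outputs, feed_dict={
        inputs: np.random.rand(8, 50, 128),
        gene: np.random.rand(8, 32),
        variation: np.random.rand(8, 32),
    })
    print(result.shape)  # (8, 100): one attended vector per document

With a 4-D input such as [batch, sentences, words, dim], the final reduction runs over the word axis instead, yielding one vector per sentence, i.e. an output of shape [batch, sentences, output_size].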
Source file: text_classification_model_han.py