text_classification_model_han.py source code

python

Project: kaggle_redefining_cancer_treatment    Author: jorgemf
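import tensorflow as tf
from tensorflow.contrib import layers  # TF 1.x only; provides fully_connected


def _attention(self, inputs, output_size, gene, variation, activation_fn=tf.tanh):
    """Attention pooling whose query is built from the gene and variation features.

    inputs: [batch, steps, features] or [batch, sentences, words, features].
    Returns the attention-weighted sum over axis -2.
    """
    inputs_shape = inputs.get_shape()
    if len(inputs_shape) != 3 and len(inputs_shape) != 4:
        raise ValueError('Shape of input must have 3 or 4 dimensions')
    # Project every step into the attention space.
    input_projection = layers.fully_connected(inputs, output_size,
                                              activation_fn=activation_fn)
    # Per-document context (query) vector from the gene and variation inputs.
    doc_context = tf.concat([gene, variation], axis=1)
    doc_context_vector = layers.fully_connected(doc_context, output_size,
                                                activation_fn=activation_fn)
    # Add singleton axes so the context vector broadcasts over the step axes.
    doc_context_vector = tf.expand_dims(doc_context_vector, 1)
    if len(inputs_shape) == 4:
        doc_context_vector = tf.expand_dims(doc_context_vector, 1)

    # Score each step by its dot product with the context vector.
    vector_attn = input_projection * doc_context_vector
    vector_attn = tf.reduce_sum(vector_attn, axis=-1, keepdims=True)
    # Normalize over axis -2, the axis summed below, so 4-D inputs get one
    # weight distribution over the words of each sentence.
    attention_weights = tf.nn.softmax(vector_attn, axis=-2)
    weighted_projection = input_projection * attention_weights
    outputs = tf.reduce_sum(weighted_projection, axis=-2)

    return outputs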
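To make the tensor shapes easier to follow, here is a minimal NumPy sketch of the same kind of forward computation. It is not the project's code: it assumes a plain affine layer followed by tanh in place of `layers.fully_connected`, uses randomly initialized weights, and the names `dense`, `context_attention` and the `params` keys are hypothetical. The scores are normalized over axis -2, the axis that is pooled, so a word-level (4-D) input gets one weight distribution per sentence.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along `axis`.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def dense(x, weight, bias, activation=np.tanh):
    # Hypothetical stand-in for layers.fully_connected: affine map on the last axis.
    return activation(x @ weight + bias)


def context_attention(inputs, gene, variation, params):
    """NumPy sketch of the _attention forward pass (hypothetical helper).

    inputs: (batch, steps, features) or (batch, sentences, words, features).
    gene, variation: (batch, gene_dim) and (batch, variation_dim).
    params: hypothetical dict holding w_in, b_in, w_ctx, b_ctx.
    """
    if inputs.ndim not in (3, 4):
        raise ValueError("Shape of input must have 3 or 4 dimensions")
    projection = dense(inputs, params["w_in"], params["b_in"])
    context = dense(np.concatenate([gene, variation], axis=1),
                    params["w_ctx"], params["b_ctx"])  # (batch, out)
    # Reshape to (batch, 1, out) or (batch, 1, 1, out) so it broadcasts over steps.
    context = context.reshape(context.shape[0], *([1] * (inputs.ndim - 2)), -1)
    scores = np.sum(projection * context, axis=-1, keepdims=True)
    weights = softmax(scores, axis=-2)  # one distribution over the pooled axis
    return np.sum(projection * weights, axis=-2), weights


# Usage with random data (all sizes are arbitrary).
rng = np.random.default_rng(0)
batch, sentences, words, features, out = 2, 3, 5, 8, 4
gene_dim, variation_dim = 6, 6
params = {
    "w_in": rng.normal(size=(features, out)), "b_in": np.zeros(out),
    "w_ctx": rng.normal(size=(gene_dim + variation_dim, out)), "b_ctx": np.zeros(out),
}
gene = rng.normal(size=(batch, gene_dim))
variation = rng.normal(size=(batch, variation_dim))

word_inputs = rng.normal(size=(batch, sentences, words, features))
sentence_vectors, word_weights = context_attention(word_inputs, gene, variation, params)
print(sentence_vectors.shape)  # (2, 3, 4)

sentence_inputs = rng.normal(size=(batch, sentences, features))
doc_vectors, sentence_weights = context_attention(sentence_inputs, gene, variation, params)
print(doc_vectors.shape)  # (2, 4)
```

Because the weights sum to 1 along axis -2, the word-level call yields one context-weighted vector per sentence and the sentence-level call yields one vector per document, presumably the two levels of the hierarchical attention network (HAN) the file name refers to.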