model.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:web_page_classification 作者: yuhui-lin 项目源码 文件源码
def a_high_classifier(self, page_batch, low_classifier):
        """Build the high-level (RNN) classifier over a target page and its relatives.

        The low-level classifier is mapped over the target page and over its
        unlabeled relatives, sharing one set of variables. Those results are
        concatenated with the labeled relatives along axis 1 and fed as a
        sequence to a GRU whose size is ``FLAGS.num_cats``.

        Args:
            page_batch: 5-tuple ``(target_batch, un_batch, un_len, la_batch,
                la_len)``. ``un_batch`` and ``la_batch`` are converted with
                ``tf.sparse_tensor_to_dense``; ``un_len`` and ``la_len`` are
                summed (plus one for the target) into the RNN sequence length.
            low_classifier: callable passed to ``tf.map_fn``; it is applied
                once per batch entry to that entry's stack of pages.

        Returns:
            The final state returned by ``tf.nn.dynamic_rnn``.
        """
        (target_pages, unlabeled_sparse, unlabeled_count,
         labeled_sparse, labeled_count) = page_batch

        with tf.variable_scope("low_classifier") as shared_scope:
            # Give the target a "relatives" axis of size one:
            # [batch_size, 1, html_len, we_dim]
            target_stack = tf.expand_dims(target_pages, 1)
            # [batch_size, 1, num_cats]
            target_scores = tf.map_fn(low_classifier, target_stack,
                                      name="map_fn")

            # From here on, low_classifier reuses the variables created by
            # the call above instead of creating a second set.
            shared_scope.reuse_variables()

            unlabeled_dense = tf.sparse_tensor_to_dense(unlabeled_sparse)
            unlabeled_pages = tf.reshape(
                unlabeled_dense,
                [FLAGS.batch_size, -1, FLAGS.html_len, FLAGS.we_dim])
            # All relatives of one target are classified together as a batch:
            # [batch_size, num_relatives (variable), num_cats]
            unlabeled_scores = tf.map_fn(low_classifier, unlabeled_pages,
                                         name="map_fn")

        # Labeled relatives are reshaped directly to [batch_size, -1, num_cats];
        # they are not passed through low_classifier.
        labeled_dense = tf.sparse_tensor_to_dense(labeled_sparse)
        labeled_scores = tf.reshape(labeled_dense,
                                    [FLAGS.batch_size, -1, FLAGS.num_cats])

        # RNN input sequence: unlabeled relatives, labeled relatives, target.
        rnn_inputs = tf.concat(
            1, [unlabeled_scores, labeled_scores, target_scores])

        # Pages per target = unlabeled + labeled + 1 (the target itself).
        # NOTE(review): sequence_length counts steps from the start of axis 1;
        # confirm that padding in the dense relatives cannot precede the target.
        relatives_count = tf.add(unlabeled_count, labeled_count)
        one_per_target = tf.ones([FLAGS.batch_size], dtype=tf.int32)
        pages_per_target = tf.add(relatives_count, one_per_target)

        # High-level classifier: a GRU over the per-page category vectors.
        with tf.variable_scope("dynamic_rnn"):
            gru = tf.nn.rnn_cell.GRUCell(num_units=FLAGS.num_cats)
            _, final_state = tf.nn.dynamic_rnn(
                gru,
                inputs=rnn_inputs,
                sequence_length=pages_per_target,
                dtype=tf.float32)

        return final_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号