def a_high_classifier(self, page_batch, low_classifier):
    """Build the high-level classifier: a GRU over per-page class scores.

    The target page and its unlabeled relatives are scored by
    ``low_classifier`` (both calls share one set of variables under the
    "low_classifier" scope).  Those scores are concatenated along axis 1
    with the dense scores of the labeled relatives and fed to a GRU through
    ``tf.nn.dynamic_rnn``.

    Args:
        page_batch: 5-tuple ``(target_batch, un_batch, un_len, la_batch,
            la_len)``.  ``un_batch`` and ``la_batch`` are sparse tensors
            (they are densified here); ``un_len`` and ``la_len`` are summed
            with an int32 vector of length ``FLAGS.batch_size``.
            ``target_batch`` is presumably
            ``[batch_size, html_len, we_dim]`` -- confirm against the caller.
        low_classifier: callable handed to ``tf.map_fn``; it receives one
            batch element at a time (all pages belonging to one target) and
            is assumed to return ``num_cats`` scores per page.

    Returns:
        The final state returned by ``tf.nn.dynamic_rnn`` for a GRU cell
        with ``FLAGS.num_cats`` units.
    """
    target_batch, un_batch, un_len, la_batch, la_len = page_batch

    with tf.variable_scope("low_classifier") as low_scope:
        # Add a singleton "pages" axis so the target goes through the same
        # map_fn path as the relatives: [batch_size, 1, html_len, we_dim].
        target_exp = tf.expand_dims(target_batch, 1)
        # [batch_size, 1, num_cats]
        target_logits = tf.map_fn(low_classifier, target_exp, name="map_fn")

        # From here on, low_classifier reuses the variables created by the
        # call above instead of creating a second set.
        low_scope.reuse_variables()

        un_rel = tf.sparse_tensor_to_dense(un_batch)
        un_rel = tf.reshape(
            un_rel,
            [FLAGS.batch_size, -1, FLAGS.html_len, FLAGS.we_dim])
        # Classify the unlabeled relatives; all relatives of one target
        # form one batch for low_classifier.
        # [batch_size, num_relatives (variable), num_cats]
        un_rel = tf.map_fn(low_classifier, un_rel, name="map_fn")

    # Labeled relatives are already given as per-category values, so they
    # are only densified and reshaped to [batch_size, -1, num_cats].
    la_rel = tf.sparse_tensor_to_dense(la_batch)
    la_rel = tf.reshape(la_rel, [FLAGS.batch_size, -1, FLAGS.num_cats])

    # Sequence fed to the RNN: unlabeled relatives, labeled relatives, then
    # the target itself.  (tf.concat(axis, values) is the pre-1.0
    # TensorFlow argument order, matching the rest of this function.)
    # NOTE(review): if un_rel / la_rel are zero-padded to the longest
    # example in the batch, the first `num_pages` steps of a shorter example
    # would include padding rows and stop before the target -- confirm the
    # input layout against the data pipeline.
    concat_inputs = tf.concat(1, [un_rel, la_rel, target_logits])

    # Number of pages per target: its relatives plus the target itself.
    num_pages = tf.add(
        tf.add(un_len, la_len),
        tf.ones([FLAGS.batch_size], dtype=tf.int32))

    # High-level classifier.  This scope is a sibling of "low_classifier",
    # not nested in it: that scope is in reuse mode by now, so creating the
    # GRU variables inside it would fail.
    with tf.variable_scope("dynamic_rnn"):
        cell = tf.nn.rnn_cell.GRUCell(num_units=FLAGS.num_cats)
        # Only the final state is needed; per-step outputs are discarded.
        _, state = tf.nn.dynamic_rnn(cell,
                                     inputs=concat_inputs,
                                     sequence_length=num_pages,
                                     dtype=tf.float32)

    return state
# 评论列表 (comment list)
# 文章目录 (table of contents)