def __init__(self, config):
    """Build the model graph described by ``config``.

    Creates a ``global_step`` counter and the input placeholders. It then
    calls ``self._build_forward()``, ``self._build_loss()`` and
    ``self._get_ema_op()``, and finally merges every summary op in the
    default graph into ``self.summary``.

    Args:
        config: configuration object. This method reads ``max_num_sents``,
            ``max_sent_size``, ``max_ques_size``, ``max_word_size`` and
            ``max_tree_height`` to size the placeholders.
    """
    self.config = config
    # Non-trainable int32 scalar step counter, initialised to 0.
    self.global_step = tf.get_variable('global_step', shape=[], dtype='int32',
                                       initializer=tf.constant_initializer(0),
                                       trainable=False)

    # Static placeholder sizes; the leading (batch) dimension is left as None.
    M, JX, JQ, W, H = (config.max_num_sents, config.max_sent_size,
                       config.max_ques_size, config.max_word_size,
                       config.max_tree_height)

    # Forward inputs. Per-axis meanings are inferred from the config field
    # names only -- NOTE(review): confirm against the code that feeds them.
    self.x = tf.placeholder('int32', [None, M, JX], name='x')        # presumably context word ids
    self.cx = tf.placeholder('int32', [None, M, JX, W], name='cx')   # presumably context char ids
    self.q = tf.placeholder('int32', [None, JQ], name='q')           # presumably question word ids
    self.cq = tf.placeholder('int32', [None, JQ, W], name='cq')      # presumably question char ids
    self.tx = tf.placeholder('int32', [None, M, H, JX], name='tx')   # H = config.max_tree_height
    self.tx_edge_mask = tf.placeholder('bool', [None, M, H, JX, JX], name='tx_edge_mask')
    self.y = tf.placeholder('bool', [None, M, H, JX], name='y')
    # Scalar flag fed at run time.
    self.is_train = tf.placeholder('bool', [], name='is_train')

    # Forward outputs / loss inputs. Initialised to None here; presumably
    # populated by _build_forward (defined elsewhere) -- TODO confirm.
    self.logits = None
    self.yp = None
    self.var_list = None
    # Loss output; presumably populated by _build_loss -- TODO confirm.
    self.loss = None

    self._build_forward()
    self._build_loss()
    self.ema_op = self._get_ema_op()

    # tf.merge_all_summaries was deprecated in TF 0.12 and removed in TF 1.0
    # in favour of tf.summary.merge_all. Prefer the new API and fall back to
    # the old name so the model still builds on older releases. This must run
    # after the build calls above so the summaries they create are included.
    merge_all = getattr(getattr(tf, 'summary', None), 'merge_all', None)
    if merge_all is None:
        merge_all = tf.merge_all_summaries
    self.summary = merge_all()
评论列表
文章目录