def __init__(self, config, scope):
    """Create the input placeholders, build the graph, and merge summaries.

    Args:
        config: hyper-parameter object. The attributes read directly here are
            ``batch_size``, ``max_num_sents``, ``max_sent_size``,
            ``max_ques_size``, ``max_word_size``, ``word_emb_size`` and
            ``mode``. The ``_build_*`` helpers may read more (not visible
            here).
        scope: passed as the ``scope`` filter to ``tf.get_collection`` when
            building ``self.summary``.
    """
    self.scope = scope
    self.config = config
    # Scalar int32 step counter starting at 0; trainable=False keeps it out
    # of the trainable-variable set.
    self.global_step = tf.get_variable(
        'global_step', shape=[], dtype='int32',
        initializer=tf.constant_initializer(0), trainable=False)

    # Static dimensions used in the placeholder shapes below.
    N = config.batch_size
    M = config.max_num_sents
    JX = config.max_sent_size
    JQ = config.max_ques_size
    W = config.max_word_size

    # Forward inputs. N is the leading (batch) dimension of every input.
    # Semantic roles below are inferred from names -- presumably x/q are
    # word ids, cx/cq are per-word character ids, and *_mask mark real
    # (non-padding) positions; confirm against the data feeder.
    # NOTE(review): x, cx and x_mask leave the per-sentence word axis dynamic
    # (None), but y fixes it to JX -- confirm the feeder pads x to JX or that
    # the loss reconciles the two shapes.
    self.x = tf.placeholder('int32', [N, M, None], name='x')
    self.cx = tf.placeholder('int32', [N, M, None, W], name='cx')
    self.x_mask = tf.placeholder('bool', [N, M, None], name='x_mask')
    self.q = tf.placeholder('int32', [N, JQ], name='q')
    self.cq = tf.placeholder('int32', [N, JQ, W], name='cq')
    self.q_mask = tf.placeholder('bool', [N, JQ], name='q_mask')
    self.y = tf.placeholder('bool', [N, M, JX], name='y')
    # Scalar flag; presumably toggles train-only behavior (e.g. dropout) in
    # the forward graph.
    self.is_train = tf.placeholder('bool', [], name='is_train')
    # Float matrix with a dynamic row count and config.word_emb_size columns.
    self.new_emb_mat = tf.placeholder(
        'float', [None, config.word_emb_size], name='new_emb_mat')

    # Misc bookkeeping; presumably populated by the _build_* helpers.
    self.tensor_dict = {}

    # Forward outputs / loss inputs, presumably set by _build_forward().
    self.logits = None
    self.yp = None
    self.var_list = None

    # Loss output, presumably set by _build_loss().
    self.loss = None

    self._build_forward()
    self._build_loss()
    if config.mode == 'train':
        self._build_ema()

    # Merge only the summaries in the "summaries" collection that match this
    # model's scope, so summaries from other scopes in the graph are excluded.
    self.summary = tf.merge_summary(
        tf.get_collection("summaries", scope=self.scope))
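# NOTE(review): stray web-page navigation labels ("comment list", "table of
# contents") from a scraped copy of this code; they are not Python source.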