def check_params(self):
    """Validate the model configuration and derive input sizes.

    When ``self.reuse_model`` is truthy and no ``session_`` attribute exists
    yet, this also builds the training and test graphs in a fresh
    ``tf.Graph`` and restores weights from the path in ``self.reuse_model``.

    Attributes set on ``self``:
        doc_emb_actual_size: required ``doc_emb.shape[1]``, i.e.
            ``n_rel_feat * most_n_subquery + n_doc_emb``.
        query_emb_actual_size: required ``query_emb.shape[1]``, i.e.
            ``n_query_emb + 1``.
        input_dim: ``1 + most_n_subquery``.
        expand_input_dim: ``doc_emb_actual_size
            + query_emb_actual_size * most_n_subquery``.
        graph_, session_: only when a model is restored from file.

    Raises:
        ValueError: if ``interaction``, ``cell_type`` or ``optimization`` is
            not in the corresponding ``DSSA.VALID_*`` collection; if
            ``doc_emb`` or ``query_emb`` is not a ``numpy.ndarray`` with at
            least 2 dimensions; or if their ``shape[1]`` does not match the
            required size.
    """
    def _quoted(options):
        # Render allowed values as '"a", "b", ...' for error messages.
        return ', '.join('"{}"'.format(opt) for opt in options)

    if self.interaction not in DSSA.VALID_INTERACTION:
        raise ValueError('interaction not valid, it should be one of {}'
                         .format(_quoted(DSSA.VALID_INTERACTION)))
    if self.cell_type not in DSSA.VALID_CELL_TYPE:
        raise ValueError('cell_type not valid, it should be one of {}'
                         .format(_quoted(DSSA.VALID_CELL_TYPE)))

    if not isinstance(self.doc_emb, np.ndarray) or not isinstance(self.query_emb, np.ndarray):
        raise ValueError('both doc_emb and query_emb should be instances of numpy.ndarray')
    # The size checks below index shape[1]; reject lower-rank arrays with a
    # ValueError instead of letting an IndexError escape.
    for emb_name, emb in (('doc_emb', self.doc_emb), ('query_emb', self.query_emb)):
        if emb.ndim < 2:
            raise ValueError('{} should have at least 2 dimensions, got {}'
                             .format(emb_name, emb.ndim))

    self.doc_emb_actual_size = self.n_rel_feat * self.most_n_subquery + self.n_doc_emb
    if self.doc_emb.shape[1] != self.doc_emb_actual_size:
        raise ValueError('doc_emb shape[1] is unexpected. {} is desired while we got {}'
                         .format(self.doc_emb_actual_size, self.doc_emb.shape[1]))
    self.query_emb_actual_size = self.n_query_emb + 1
    if self.query_emb.shape[1] != self.query_emb_actual_size:
        raise ValueError('query_emb shape[1] is unexpected. {} is desired while we got {}'
                         .format(self.query_emb_actual_size, self.query_emb.shape[1]))

    if self.optimization not in DSSA.VALID_OPTIMIZATION:
        raise ValueError('optimization not valid, it should be one of {}'
                         .format(_quoted(DSSA.VALID_OPTIMIZATION)))

    self.input_dim = 1 + self.most_n_subquery
    # Same value as n_rel_feat * most_n_subquery + n_doc_emb
    # + (n_query_emb + 1) * most_n_subquery, reusing the sizes computed above.
    self.expand_input_dim = \
        self.doc_emb_actual_size + self.query_emb_actual_size * self.most_n_subquery

    # Read the model from file, but only once: skipped if a session exists.
    if self.reuse_model and not hasattr(self, 'session_'):
        self.graph_ = tf.Graph()
        with self.graph_.as_default():
            tf.set_random_seed(self.random_seed)
            initializer = tf.uniform_unit_scaling_initializer(seed=self.random_seed)
            with vs.variable_scope('DSSA', initializer=initializer) as scope:
                self.build_graph()
                # Switch the scope to reuse mode so build_graph_test() shares
                # the variables created by build_graph().
                scope.reuse_variables()
                self.build_graph_test()
            self.session_ = tf.Session(graph=self.graph_)
            print('load model from "{}"'.format(self.reuse_model))
            # NOTE(review): assumes self.saver is created inside build_graph(); confirm.
            self.saver.restore(self.session_, self.reuse_model)
# (Scraped web-page residue, not code: "评论列表" = "Comment list", "文章目录" = "Table of contents".)