dssa.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:DSSA 作者: jzbjyb 项目源码 文件源码
def check_params(self):
        """Validate the model configuration and derive the input sizes.

        Checks run in this order: ``interaction``, ``cell_type``, that
        ``doc_emb`` and ``query_emb`` are ``numpy.ndarray`` objects, the width
        (``shape[1]``) of ``doc_emb`` and then of ``query_emb``, and finally
        ``optimization``. It then derives ``input_dim`` and
        ``expand_input_dim``. If ``reuse_model`` is truthy and no
        ``session_`` exists yet, it builds a fresh graph and restores its
        variables from the path stored in ``self.reuse_model``.

        Attributes written: ``doc_emb_actual_size``, ``query_emb_actual_size``,
        ``input_dim``, ``expand_input_dim``, and, when a model is restored,
        ``graph_`` and ``session_``.

        Raises:
            ValueError: if ``interaction``, ``cell_type`` or ``optimization``
                is not in its ``DSSA.VALID_*`` collection; if either embedding
                is not a ``numpy.ndarray``; or if an embedding has fewer than
                two dimensions or an unexpected ``shape[1]``.
        """
        def _quoted(options):
            # Render allowed values as '"a", "b", ...' for the error messages.
            return ', '.join('"' + option + '"' for option in options)

        def _check_width(name, emb, expected):
            # shape[1] needs at least two axes; report a clear ValueError
            # instead of letting an IndexError escape for 1-D input.
            if emb.ndim < 2:
                raise ValueError('{} should have at least 2 dimensions, got {}'
                                 .format(name, emb.ndim))
            if emb.shape[1] != expected:
                raise ValueError('{} shape[1] is unexpected. {} is desired while we got {}'
                                 .format(name, expected, emb.shape[1]))

        if self.interaction not in DSSA.VALID_INTERACTION:
            raise ValueError('interaction not valid, it should be one of {}'
                             .format(_quoted(DSSA.VALID_INTERACTION)))
        if self.cell_type not in DSSA.VALID_CELL_TYPE:
            raise ValueError('cell_type not valid, it should be one of {}'
                             .format(_quoted(DSSA.VALID_CELL_TYPE)))
        if not isinstance(self.doc_emb, np.ndarray) or not isinstance(self.query_emb, np.ndarray):
            raise ValueError('both doc_emb and query_emb should be instances of numpy.ndarray')

        # Expected doc_emb width: n_rel_feat columns for each of the
        # most_n_subquery subqueries, plus n_doc_emb columns.
        self.doc_emb_actual_size = self.n_rel_feat * self.most_n_subquery + self.n_doc_emb
        _check_width('doc_emb', self.doc_emb, self.doc_emb_actual_size)
        # Expected query_emb width: n_query_emb plus one extra column.
        # NOTE(review): the meaning of the extra column is not visible here.
        self.query_emb_actual_size = self.n_query_emb + 1
        _check_width('query_emb', self.query_emb, self.query_emb_actual_size)

        if self.optimization not in DSSA.VALID_OPTIMIZATION:
            raise ValueError('optimization not valid, it should be one of {}'
                             .format(_quoted(DSSA.VALID_OPTIMIZATION)))

        self.input_dim = 1 + self.most_n_subquery
        # doc_emb width plus one query_emb width per subquery; this equals
        # n_rel_feat*m + n_doc_emb + (n_query_emb + 1)*m with m = most_n_subquery.
        self.expand_input_dim = (self.doc_emb_actual_size
                                 + self.query_emb_actual_size * self.most_n_subquery)

        # Restore a saved model; the session_ guard makes this happen at most
        # once per instance.
        if self.reuse_model and not hasattr(self, 'session_'):
            self.graph_ = tf.Graph()
            with self.graph_.as_default():
                tf.set_random_seed(self.random_seed)
                initializer = tf.uniform_unit_scaling_initializer(seed=self.random_seed)
                with vs.variable_scope('DSSA', initializer=initializer) as scope:
                    self.build_graph()
                    # The test graph is built in the same scope with reuse on,
                    # so it shares the variables created by build_graph().
                    scope.reuse_variables()
                    self.build_graph_test()
            self.session_ = tf.Session(graph=self.graph_)
            print('load model from "{}"'.format(self.reuse_model))
            # NOTE(review): self.saver is not created in this method;
            # presumably build_graph() or build_graph_test() sets it. Confirm.
            self.saver.restore(self.session_, self.reuse_model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号