ranknet.py source code

python

Project: tfranknet  Author: mzhang001
def _setup_base_graph(self):
    """
    Set up queue, variables and session
    """
    self.graph = tf.Graph()
    with self.graph.as_default() as g:
        input_dim = self.input_dim
        batch_size = self.batch_size
        hidden_units = self.hidden_units
        # Unit counts per layer: input, hidden layers, single output unit
        layer_units = [input_dim] + hidden_units + [1]
        layer_num = len(layer_units)

        # Queue for drawing shuffled batches; each element holds two
        # vectors of length input_dim
        self.queue = q = tf.RandomShuffleQueue(
            capacity=self.q_capacity,
            min_after_dequeue=self.min_after_dequeue,
            dtypes=["float", "float"],
            shapes=[[input_dim], [input_dim]])
        # Input data: two batches of shape [batch_size, input_dim]
        self.data1, self.data2 = q.dequeue_many(batch_size, name="inputs")

        self._setup_variables()
        self._setup_training()
        self._setup_prediction()
        self._setup_pretraining()
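
The excerpt builds `layer_units` but does not show how `_setup_variables` consumes it. As a minimal sketch, assuming each pair of consecutive entries defines one fully connected weight matrix (an assumption, since that method is not shown here), the shapes could be derived as follows. The helper name `layer_shapes` is hypothetical and not part of tfranknet.

```python
def layer_shapes(input_dim, hidden_units):
    """Return (fan_in, fan_out) pairs for a fully connected network.

    Hypothetical helper, not part of tfranknet: it assumes consecutive
    entries of layer_units define one weight matrix each.
    """
    # Same construction as in _setup_base_graph: input, hidden layers, 1 output
    layer_units = [input_dim] + list(hidden_units) + [1]
    # Pair each layer size with the next one
    return list(zip(layer_units[:-1], layer_units[1:]))


shapes = layer_shapes(input_dim=10, hidden_units=[32, 16])
print(shapes)  # [(10, 32), (32, 16), (16, 1)]
```

Under this assumption, `layer_num = len(layer_units)` counts layer sizes, so the number of weight matrices would be `layer_num - 1`.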