humancritic_tensorflow.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:LearningFromHumanPreferences 作者: ZachisGit 项目源码 文件源码
def init_tf(self):
        """Build the human-critic reward graph, its preference loss, and a session.

        Creates a fresh ``tf.Graph`` (stored on ``self.graph``) and, inside it:

        * two observation placeholders, ``input_o0`` and ``input_o1``, each of
          shape ``[None, self.obs_size]``, fed through ``self.create_model``;
          the second call passes ``reuse=True`` -- presumably so both segments
          are scored by the same weights (TODO confirm against create_model);
        * ``preference_distribution`` (shape ``[2]``): target probabilities
          that segment 0 / segment 1 is the preferred one;
        * ``batch_sizes`` (shape ``[2]``): divisors that turn each segment's
          summed model output into an average, ``r0`` / ``r1``.

        The loss is the cross-entropy between ``preference_distribution`` and
        ``softmax([r0, r1])``, i.e.
        ``-(pref[0] * log P(o0 > o1) + pref[1] * log P(o1 > o0))`` with
        ``P(o0 > o1) = exp(r0) / (exp(r0) + exp(r1))``. It is evaluated with
        ``tf.nn.log_softmax`` so a large averaged reward or reward gap cannot
        overflow ``exp`` or produce ``log(0)`` (which previously yielded NaN).

        Reads ``self.obs_size``, ``self.create_model``, ``self.LEARNING_RATE``
        and ``self.datetime_str``. Sets ``self.sess`` (with all global
        variables initialised), ``self.train_step`` (Adam minimising
        ``self.loss``), ``self.saver`` and ``self.checkpoint_path``.
        """
        # Clears the process-wide default graph; everything below lives in
        # self.graph instead, so this is only kept for backward compatibility.
        tf.reset_default_graph()
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Not referenced directly here; presumably consumed by create_model.
            self.initializer = tf.truncated_normal_initializer(stddev=0.3)

            self.input_o0 = tf.placeholder(shape=[None, self.obs_size], dtype=tf.float32)
            self.input_o1 = tf.placeholder(shape=[None, self.obs_size], dtype=tf.float32)
            self.preference_distribution = tf.placeholder(shape=[2], dtype=tf.float32)
            self.model_o0 = self.create_model(self.input_o0)
            self.model_o1 = self.create_model(self.input_o1, reuse=True)
            self.batch_sizes = tf.placeholder(shape=[2], dtype=tf.float32)

            # Average model output per segment (sum divided by the fed size).
            reward_o0 = tf.divide(tf.reduce_sum(self.model_o0), self.batch_sizes[0])
            reward_o1 = tf.divide(tf.reduce_sum(self.model_o1), self.batch_sizes[1])

            # Exponentiated rewards are kept as public attributes for callers;
            # they may overflow to inf for large rewards, so the loss and the
            # probabilities below no longer depend on them.
            self.model_o0_sum = tf.exp(reward_o0)
            self.model_o1_sum = tf.exp(reward_o1)

            # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b), computed stably.
            self.p_o0_o1 = tf.sigmoid(reward_o0 - reward_o1)
            self.p_o1_o0 = tf.sigmoid(reward_o1 - reward_o0)

            # log_softmax applies the log-sum-exp trick, so this scalar equals
            # -(pref[0]*log(p_o0_o1) + pref[1]*log(p_o1_o0)) without NaN/inf.
            logits = tf.stack([reward_o0, reward_o1])
            self.loss = -tf.reduce_sum(
                self.preference_distribution * tf.nn.log_softmax(logits))

            self.train_step = tf.train.AdamOptimizer(
                learning_rate=self.LEARNING_RATE).minimize(self.loss)
            # Created inside as_default(), so the session is bound to self.graph.
            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

            self.saver = tf.train.Saver(tf.global_variables())
            self.checkpoint_path = "./human_critic/hc_model/"+self.datetime_str+"/hc_model_"+self.datetime_str+".ckpt"
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号