fun_graph.py source file

python

Project: relaax  Author: deeplearninc
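import numpy as np
import tensorflow as tf
# `graph` and `cfg` come from the relaax project; their imports are not part of this excerpt.


def build_graph(self, goal, critic):
    # State-difference input: one row of size cfg.d per sample, L2-normalized per row
    self.ph_stc_diff_st = graph.Placeholder(
        np.float32, shape=(None, cfg.d), name="ph_stc_diff_st")
    s_diff_normalized = tf.nn.l2_normalize(self.ph_stc_diff_st.node, dim=1)

    # All pairwise dot products in the batch; the diagonal keeps row i against goal i.
    # This is a true cosine only if goal.node is already unit-norm.
    cosine_similarity = tf.matmul(s_diff_normalized, goal.node, transpose_b=True)
    cosine_similarity = tf.diag_part(cosine_similarity)

    # manager's advantage (R - V): R = ri + cfg.wGAMMA * R; AdvM = R - ViM
    self.ph_discounted_reward = graph.Placeholder(
        np.float32, shape=(None,), name="ph_m_discounted_reward")
    advantage = self.ph_discounted_reward.node - critic.node

    # Advantage-weighted cosine similarity, summed over the batch
    manager_loss = tf.reduce_sum(advantage * cosine_similarity)
    return manager_loss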
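
The arithmetic in `build_graph` can be checked outside TensorFlow. The following is a minimal NumPy sketch, not part of relaax: the function names (`l2_normalize_rows`, `cosine_via_diag`, `cosine_rowwise`, `manager_loss`) are hypothetical, the goal rows are assumed to be unit-norm already, and the critic output is assumed to be a 1-D array with one value per sample. It shows that taking the diagonal of `A @ B.T` gives the same result as a row-wise dot product, which avoids building the full batch-by-batch matrix.

```python
import numpy as np


def l2_normalize_rows(x, eps=1e-12):
    # Same formula as tf.nn.l2_normalize along axis 1: x / sqrt(max(sum(x**2), eps))
    sq_sum = np.sum(np.square(x), axis=1, keepdims=True)
    return x / np.sqrt(np.maximum(sq_sum, eps))


def cosine_via_diag(s_diff, goal):
    # Mirrors the snippet: full (batch, batch) product, then keep the diagonal
    return np.diag(l2_normalize_rows(s_diff) @ goal.T)


def cosine_rowwise(s_diff, goal):
    # Equivalent result without the (batch, batch) intermediate
    return np.sum(l2_normalize_rows(s_diff) * goal, axis=1)


def manager_loss(s_diff, goal, discounted_reward, value):
    # advantage = R - V, then sum of advantage-weighted similarities
    advantage = discounted_reward - value
    return np.sum(advantage * cosine_rowwise(s_diff, goal))


# Hypothetical batch of two samples with d = 2; goal rows are unit-norm
s_diff = np.array([[3.0, 4.0], [1.0, 0.0]])
goal = np.array([[0.6, 0.8], [0.0, 1.0]])
discounted_reward = np.array([2.0, 5.0])
value = np.array([0.5, 1.0])

print(cosine_via_diag(s_diff, goal))
print(cosine_rowwise(s_diff, goal))
print(manager_loss(s_diff, goal, discounted_reward, value))
```

The first sample's state difference points in the same direction as its goal and the second is orthogonal to its goal, so only the first sample's advantage contributes to the sum.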