tf_util.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:distributional_perspective_on_RL 作者: Kiwoo 项目源码 文件源码
def __call__(self, *inputvals):
        """Run ``self.outputs`` over the data inputs in mini-batches and average.

        The positional values are split in order: the first
        ``len(self.nondata_inputs)`` values are fed unchanged on every run,
        and the remaining values are sliced along their first axis into
        chunks of at most ``self.batch_size`` rows. The per-batch results are
        added element-wise and each total is finally divided by the number of
        rows ``n``.

        NOTE(review): dividing the summed batch results by ``n`` gives a
        per-row mean only if every output is itself a sum over the rows of
        the batch -- confirm against how ``self.outputs`` is built.

        Args:
            *inputvals: values for ``self.nondata_inputs`` followed by values
                for ``self.data_inputs``. Every data value must expose
                ``shape[0]`` and support slicing; all must have the same
                number of rows.

        Returns:
            A list with one entry per output: the accumulated batch results
            divided by ``n``.

        Raises:
            AssertionError: if the number of values does not match the
                number of inputs, or the data values disagree on row count.
            ValueError: if there are no data values, or they have zero rows
                (there would be nothing to evaluate and ``n`` would be 0).
        """
        num_nondata = len(self.nondata_inputs)
        assert len(inputvals) == num_nondata + len(self.data_inputs), \
            "expected one value per nondata/data input"
        nondata_vals = inputvals[:num_nondata]
        data_vals = inputvals[num_nondata:]
        if not data_vals:
            # Previously surfaced as a bare IndexError on data_vals[0].
            raise ValueError("at least one data input is required for batching")

        n = data_vals[0].shape[0]
        for v in data_vals[1:]:
            assert v.shape[0] == n, \
                "all data inputs must have the same number of rows"
        if n == 0:
            # Previously the batch loop never ran and the averaging step
            # failed with an UnboundLocalError on `results`.
            raise ValueError("data inputs are empty; nothing to evaluate")

        feed_dict = dict(zip(self.nondata_inputs, nondata_vals))
        # The default session is looked up once rather than on every batch.
        session = tf.get_default_session()
        totals = None
        for start in range(0, n, self.batch_size):
            stop = min(start + self.batch_size, n)
            # Only the data placeholders change between batches; the nondata
            # entries stay in feed_dict for every run.
            for var, val in zip(self.data_inputs, data_vals):
                feed_dict[var] = val[start:stop]
            results = session.run(self.outputs, feed_dict=feed_dict)
            if totals is None:
                # Copy into a fresh list so accumulation never aliases (or
                # depends on the mutability of) the object returned by run().
                totals = list(results)
            else:
                totals = [acc + res for acc, res in zip(totals, results)]
        return [total / n for total in totals]

# ================================================================
# Modules
# ================================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号