complex_network.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tensorforce 作者: reinforceio 项目源码 文件源码
def tf_apply(self, x, update):
    """Gather the configured inputs, promote them to a common dtype, and concatenate.

    Each entry of ``self.inputs`` names a tensor to merge:
    ``"*"`` or ``"previous"`` selects ``x``; any other name is looked up in
    ``self.named_tensors``. If the gathered tensors have mixed dtypes, every
    tensor is cast up to the "widest" dtype present. The promotion order is
    bool < int32 < int64 < float32 < float64. The tensors are then joined with
    ``tf.concat`` along ``self.axis``.

    Args:
        x: Tensor used for the ``"*"`` / ``"previous"`` input names.
        update: Not used by this method; kept for interface compatibility.

    Returns:
        The concatenated tensor.

    Raises:
        TensorForceError: if an input name is neither ``"*"``/``"previous"``
            nor a key of ``self.named_tensors``; if an input has a dtype
            outside the supported promotion table; or if ``self.inputs``
            yields nothing to merge.
    """
    inputs_to_merge = list()
    for name in self.inputs:
        # "*" / "previous" refer to x, like a normal network_spec. These names
        # are resolved here rather than via named_tensors to avoid unintended
        # lookups.
        if name == "*" or name == "previous":
            inputs_to_merge.append(x)
        elif name in self.named_tensors:
            inputs_to_merge.append(self.named_tensors[name])
        else:
            # Unknown name: show the user which inputs are available.
            keys = list(self.named_tensors)
            raise TensorForceError(
                'ComplexNetwork input "{}" doesn\'t exist, Available inputs: {}'.format(name, keys)
            )

    if not inputs_to_merge:
        # tf.concat on an empty list fails with an opaque op-construction
        # error; report the configuration problem directly instead.
        raise TensorForceError('ComplexNetwork has no inputs to merge')

    # Promotion table, lowest to highest. TensorFlow will not concatenate
    # mixed dtypes, so everything is cast up to the highest-ranked dtype
    # present. This is a quick, promote-only scheme: e.g. int64 -> float32
    # can lose precision for very large integers.
    promotion_order = (
        ('bool', tf.bool),
        ('int32', tf.int32),
        ('int64', tf.int64),
        ('float32', tf.float32),
        ('float64', tf.float64),
    )
    rank_of = {dtype_name: rank for rank, (dtype_name, _) in enumerate(promotion_order)}

    ranks = []
    for tensor in inputs_to_merge:
        key = str(tensor.dtype.name)
        if key not in rank_of:
            raise TensorForceError('Network spec input does not support dtype {}'.format(key))
        ranks.append(rank_of[key])

    target_rank = max(ranks)
    target_dtype = promotion_order[target_rank][1]

    # tf.cast replaces the deprecated tf.to_int32 / to_int64 / to_float /
    # to_double wrappers (removed in TF 2) and yields the same dtypes.
    # Tensors already at the target rank are passed through untouched.
    inputs_to_merge = [
        tf.cast(tensor, target_dtype) if rank < target_rank else tensor
        for tensor, rank in zip(inputs_to_merge, ranks)
    ]

    return tf.concat(inputs_to_merge, self.axis)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号