def tf_apply(self, x, update):
    """Gather the configured inputs, promote them to a common dtype, and concatenate.

    Each entry of ``self.inputs`` names one tensor to merge: ``"*"`` or
    ``"previous"`` selects ``x``; any other name is looked up in
    ``self.named_tensors``. All gathered tensors are cast up to the highest-ranked
    dtype among them (bool < int32 < int64 < float32 < float64) so that
    ``tf.concat`` does not fail on mixed dtypes, then concatenated along
    ``self.axis``.

    Args:
        x: Tensor selected by the ``"*"`` / ``"previous"`` input names.
        update: Not used by this method; kept for interface compatibility.

    Returns:
        The concatenated tensor.

    Raises:
        TensorForceError: If an input name is not ``"*"``/``"previous"`` and is
            not a key of ``self.named_tensors``, or if an input's dtype is not
            one of bool/int32/int64/float32/float64.
    """
    inputs_to_merge = list()
    for name in self.inputs:
        # Previous input, by name or "*", like normal network_spec.
        # Not using named_tensors for it, as there could be unintended outcomes.
        if name == "*" or name == "previous":
            inputs_to_merge.append(x)
        elif name in self.named_tensors:
            inputs_to_merge.append(self.named_tensors[name])
        else:
            # Unknown key: tell the user which inputs are available.
            keys = list(self.named_tensors)
            raise TensorForceError(
                'ComplexNetwork input "{}" doesn\'t exist, Available inputs: {}'.format(name, keys)
            )

    # Quick & dirty promotion so TensorFlow doesn't reject mixed dtypes:
    # rank each supported dtype, then cast every input up to the highest rank seen.
    # NOTE: int32/int64 -> float32 promotion can lose integer precision.
    cast_type_dict = {
        'bool': 0,
        'int32': 10,
        'int64': 20,
        'float32': 30,
        'float64': 40
    }
    # Target dtype per level. tf.cast replaces the deprecated
    # tf.to_int32 / tf.to_int64 / tf.to_float / tf.to_double helpers,
    # which no longer exist in TensorFlow 2.x. Level 0 (bool) is never a
    # cast target, since nothing ranks below it.
    cast_type_target = {
        10: tf.int32,
        20: tf.int64,
        30: tf.float32,
        40: tf.float64
    }

    # Scan inputs for the highest cast level (and reject unsupported dtypes).
    cast_type_level = 0
    for tensor in inputs_to_merge:
        key = str(tensor.dtype.name)
        if key not in cast_type_dict:
            raise TensorForceError('Network spec input does not support dtype {}'.format(key))
        cast_type_level = max(cast_type_level, cast_type_dict[key])

    # Cast every input that ranks below the common level up to it.
    for index, tensor in enumerate(inputs_to_merge):
        key = str(tensor.dtype.name)
        if cast_type_dict[key] < cast_type_level:
            inputs_to_merge[index] = tf.cast(tensor, dtype=cast_type_target[cast_type_level])

    return tf.concat(inputs_to_merge, self.axis)
评论列表
文章目录