def process_state(self, state):
    """Cast a uint8 state tensor to float32 and rescale it.

    State placeholders are tf.uint8 for fast transfer to the GPU, so they
    must be cast to float32 before the rest of the TF graph consumes them.

    Args:
        state: node of the TF graph with shape
            (batch_size, height, width, nchannels) and dtype tf.uint8.

    Returns:
        A tf.float32 tensor holding ``state`` divided element-wise by
        ``self.config.high``. For example, with ``self.config.high == 255``,
        values in [0, 255] are mapped to [0, 1].
    """
    # Cast before scaling so the division is carried out on float32 values.
    state = tf.cast(state, tf.float32)
    # NOTE(review): assumes self.config.high is a nonzero scalar (e.g. 255) — confirm.
    state /= self.config.high
    return state
# (web-page residue) 评论列表 — "Comment list"
# (web-page residue) 文章目录 — "Table of contents"